#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Image search service (image_search_service.py).

Features: batch computation of image feature vectors, index building,
and similar-image search.
"""
import asyncio
import base64
import json
import os
import time
from http import HTTPStatus
from typing import List, Dict, Optional, Tuple

import faiss
import numpy as np

from utils.logger import chat_logger

# The dashscope SDK is a hard requirement: log the problem and re-raise if it is missing.
try:
    import dashscope
    from dashscope import MultiModalEmbedding
    DASHSCOPE_AVAILABLE = True
except ImportError:
    chat_logger.error("dashscope库未安装,请使用 pip install dashscope 安装")
    raise ImportError("dashscope库未安装,请使用 pip install dashscope 安装")
  24. class ImageSearchService:
  25. def __init__(self):
  26. """初始化图片搜索服务"""
  27. self.index = None # FAISS索引
  28. self.image_mapping = {} # 图片ID到图片信息的映射
  29. self.last_build_time = None # 最后构建索引的时间
  30. self.dimension = 768 # 默认特征向量维度(与tongyi-embedding-vision-plus模型一致)
  31. # 配置参数
  32. self.config = {
  33. "model_name": "tongyi-embedding-vision-plus",
  34. "retry_count": 3,
  35. "batch_size": 10
  36. }
  37. chat_logger.info("图片搜索服务初始化完成")
  38. async def calculate_vector(self, image_bytes: Optional[bytes] = None, text: Optional[str] = None) -> Optional[List[float]]:
  39. """计算图片或文本的特征向量
  40. Args:
  41. image_bytes: 图片字节数据(可选)
  42. text: 文本数据(可选)
  43. Returns:
  44. 特征向量列表,如果失败则返回None
  45. """
  46. try:
  47. # 检查API Key
  48. if not dashscope.api_key and 'DASHSCOPE_API_KEY' not in os.environ:
  49. chat_logger.error("请配置 DASHSCOPE_API_KEY 环境变量")
  50. return None
  51. # 构建input_data
  52. input_data = []
  53. if image_bytes:
  54. # 以图搜图
  55. image_base64 = base64.b64encode(image_bytes).decode('utf-8')
  56. image_data = f"data:image/jpeg;base64,{image_base64}"
  57. input_data.append({"image": image_data})
  58. elif text:
  59. # 以文搜图
  60. input_data.append({"text": text})
  61. else:
  62. chat_logger.error("请提供图片或文本数据")
  63. return None
  64. # 调用模型接口
  65. for attempt in range(self.config["retry_count"]):
  66. try:
  67. resp = dashscope.MultiModalEmbedding.call(
  68. model=self.config["model_name"],
  69. input=input_data
  70. )
  71. if resp.status_code == HTTPStatus.OK:
  72. embedding = resp.output["embeddings"][0]["embedding"]
  73. return embedding
  74. else:
  75. chat_logger.error(f"API调用失败: {resp.message} (状态码: {resp.status_code})")
  76. if attempt < self.config["retry_count"] - 1:
  77. chat_logger.info(f"{self.config['retry_count'] - attempt - 1}秒后重试...")
  78. time.sleep(1)
  79. else:
  80. return None
  81. except Exception as e:
  82. chat_logger.error(f"处理数据时出错: {str(e)}")
  83. if attempt < self.config["retry_count"] - 1:
  84. chat_logger.info(f"{self.config['retry_count'] - attempt - 1}秒后重试...")
  85. time.sleep(1)
  86. else:
  87. return None
  88. except Exception as e:
  89. chat_logger.error(f"计算特征向量失败: {str(e)}")
  90. return None
  91. async def batch_calculate_vectors(self, image_items: List[Dict]) -> List[Dict]:
  92. """批量计算图片特征向量
  93. Args:
  94. image_items: 图片项列表,每个项包含image(base64编码)和image_id
  95. Returns:
  96. 包含特征向量的图片项列表
  97. """
  98. results = []
  99. for item in image_items:
  100. try:
  101. # 解码base64图片
  102. if "," in item["image"]:
  103. base64_str = item["image"].split(",", 1)[1]
  104. else:
  105. base64_str = item["image"]
  106. image_bytes = base64.b64decode(base64_str)
  107. # 计算特征向量
  108. vector = await self.calculate_vector(image_bytes)
  109. if vector:
  110. results.append({
  111. "image_id": item.get("image_id"),
  112. "vector": vector,
  113. "success": True
  114. })
  115. else:
  116. results.append({
  117. "image_id": item.get("image_id"),
  118. "vector": None,
  119. "success": False,
  120. "error": "计算特征向量失败"
  121. })
  122. except Exception as e:
  123. chat_logger.error(f"处理图片失败 (ID: {item.get('image_id')}): {str(e)}")
  124. results.append({
  125. "image_id": item.get("image_id"),
  126. "vector": None,
  127. "success": False,
  128. "error": str(e)
  129. })
  130. return results
  131. async def build_index(self, image_vectors: List[Dict]) -> int:
  132. """构建索引及映射关系
  133. Args:
  134. image_vectors: 图片向量列表,每个项包含image_id、vector、image_name、image_path等
  135. Returns:
  136. 索引的图片数量
  137. """
  138. try:
  139. if not image_vectors:
  140. chat_logger.warning("没有图片向量需要索引")
  141. return 0
  142. # 提取向量和图片信息
  143. vectors = []
  144. self.image_mapping = {}
  145. for item in image_vectors:
  146. if "vector" in item and item["vector"]:
  147. vectors.append(item["vector"])
  148. self.image_mapping[item["image_id"]] = {
  149. "image_name": item.get("image_name"),
  150. "image_path": item.get("image_path")
  151. }
  152. if not vectors:
  153. chat_logger.warning("没有有效的向量数据")
  154. return 0
  155. # 转换为numpy数组
  156. vectors_np = np.array(vectors, dtype=np.float32)
  157. # 更新维度信息
  158. self.dimension = vectors_np.shape[1]
  159. # 创建FAISS索引
  160. self.index = faiss.IndexFlatIP(self.dimension) # 使用内积相似度
  161. # 归一化向量
  162. faiss.normalize_L2(vectors_np)
  163. # 添加向量到索引
  164. self.index.add(vectors_np)
  165. self.last_build_time = time.strftime('%Y-%m-%d %H:%M:%S')
  166. indexed_count = len(vectors)
  167. chat_logger.info(f"成功构建索引,索引了 {indexed_count} 个图片向量")
  168. return indexed_count
  169. except Exception as e:
  170. chat_logger.error(f"构建索引失败: {str(e)}")
  171. return 0
  172. async def search(self, image_bytes: Optional[bytes] = None, text: Optional[str] = None,
  173. top_k: int = 10) -> List[Dict]:
  174. """搜索相似图片
  175. Args:
  176. image_bytes: 图片字节数据(以图搜图)
  177. text: 文字描述(以文搜图)
  178. top_k: 返回结果数量
  179. Returns:
  180. 搜索结果列表
  181. """
  182. try:
  183. if self.index is None:
  184. chat_logger.warning("索引未构建,请先构建索引")
  185. return []
  186. # 计算查询向量
  187. query_vector = await self.calculate_vector(image_bytes=image_bytes, text=text)
  188. if not query_vector:
  189. chat_logger.error("无法获取查询特征向量")
  190. return []
  191. # 执行搜索
  192. if query_vector:
  193. # 转换为numpy数组并归一化
  194. query_vector_np = np.array([query_vector], dtype=np.float32)
  195. faiss.normalize_L2(query_vector_np)
  196. # 搜索
  197. distances, indices = self.index.search(query_vector_np, top_k)
  198. # 构建结果
  199. results = []
  200. for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
  201. # 计算相似度(内积转换为0-1范围)
  202. similarity = float(distance)
  203. # 直接添加结果,不使用阈值过滤
  204. # 获取图片ID
  205. image_id = list(self.image_mapping.keys())[idx]
  206. image_info = self.image_mapping.get(image_id, {})
  207. results.append({
  208. "image_id": image_id,
  209. "similarity": similarity,
  210. "image_name": image_info.get("image_name"),
  211. "image_path": image_info.get("image_path")
  212. })
  213. chat_logger.info(f"搜索完成,找到 {len(results)} 个相似图片")
  214. return results
  215. else:
  216. return []
  217. except Exception as e:
  218. chat_logger.error(f"搜索失败: {str(e)}")
  219. return []
  220. async def get_index_status(self) -> Dict:
  221. """获取索引状态
  222. Returns:
  223. 索引状态信息
  224. """
  225. if self.index is None:
  226. return {
  227. "indexed_count": 0,
  228. "index_size": "0KB",
  229. "last_build_time": None,
  230. "dimension": self.dimension
  231. }
  232. indexed_count = self.index.ntotal
  233. # 估算索引大小
  234. index_size = f"{indexed_count * self.dimension * 4 / 1024:.2f}KB" # 每个float32占4字节
  235. return {
  236. "indexed_count": indexed_count,
  237. "index_size": index_size,
  238. "last_build_time": self.last_build_time,
  239. "dimension": self.dimension
  240. }
  241. async def clear_index(self):
  242. """清空索引"""
  243. try:
  244. self.index = None
  245. self.image_mapping = {}
  246. self.last_build_time = None
  247. chat_logger.info("索引已清空")
  248. except Exception as e:
  249. chat_logger.error(f"清空索引失败: {str(e)}")
# Create the module-level service instance (constructed when this module is imported).
image_search_service = ImageSearchService()