#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Image search service.

Features: batch-compute image embedding vectors, build a FAISS index over
them, and search for similar images either by image or by text.
"""

import asyncio
import os
import json
import numpy as np
import faiss
import base64
import time
from typing import List, Dict, Optional, Tuple
from utils.logger import chat_logger

# Import the dashscope library; it is a hard requirement of this module.
try:
    import dashscope
    from dashscope import MultiModalEmbedding
    from http import HTTPStatus
    DASHSCOPE_AVAILABLE = True
except ImportError as exc:
    chat_logger.error("dashscope库未安装,请使用 pip install dashscope 安装")
    raise ImportError("dashscope库未安装,请使用 pip install dashscope 安装") from exc


class ImageSearchService:
    """In-memory image similarity search backed by DashScope embeddings and FAISS.

    Vectors are L2-normalized and stored in an inner-product index
    (``faiss.IndexFlatIP``), so search scores are cosine similarities.
    """

    def __init__(self):
        """Initialise an empty search service (no index built yet)."""
        self.index = None  # FAISS index; None until build_index succeeds
        self.image_mapping = {}  # image_id -> {"image_name": ..., "image_path": ...}
        # Row position in self.index -> image_id. Keeps search hits aligned with
        # the vectors actually added, independent of dict key ordering.
        self._index_ids: List = []
        self.last_build_time = None  # Time of the last successful build
        # Default embedding dimension (matches tongyi-embedding-vision-plus);
        # replaced by the real dimension whenever an index is built.
        self.dimension = 768

        # Configuration
        self.config = {
            "model_name": "tongyi-embedding-vision-plus",
            "retry_count": 3,
            "batch_size": 10,
            "retry_delay": 1,  # seconds to wait between API retries
        }

        chat_logger.info("图片搜索服务初始化完成")

    async def calculate_vector(self, image_bytes: Optional[bytes] = None,
                               text: Optional[str] = None) -> Optional[List[float]]:
        """Compute the embedding vector of an image or of a text query.

        If both are supplied, the image takes precedence.

        Args:
            image_bytes: Raw image bytes (optional). Sent to the API as a
                ``data:image/jpeg;base64,...`` URI.
            text: Text query (optional).

        Returns:
            The embedding as a list of floats, or None on failure (missing API
            key, no input, or every retry attempt failed).
        """
        try:
            # Check the API key
            if not dashscope.api_key and 'DASHSCOPE_API_KEY' not in os.environ:
                chat_logger.error("请配置 DASHSCOPE_API_KEY 环境变量")
                return None

            # Build the request input
            input_data = []
            if image_bytes:
                # Search by image
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                image_data = f"data:image/jpeg;base64,{image_base64}"
                input_data.append({"image": image_data})
            elif text:
                # Search by text
                input_data.append({"text": text})
            else:
                chat_logger.error("请提供图片或文本数据")
                return None

            retry_count = self.config["retry_count"]
            retry_delay = self.config.get("retry_delay", 1)
            loop = asyncio.get_running_loop()

            for attempt in range(retry_count):
                try:
                    # The DashScope SDK call is synchronous (blocking HTTP); run
                    # it in the default executor so it does not stall the event
                    # loop for every other coroutine.
                    resp = await loop.run_in_executor(
                        None,
                        lambda: dashscope.MultiModalEmbedding.call(
                            model=self.config["model_name"],
                            input=input_data,
                        ),
                    )
                    if resp.status_code == HTTPStatus.OK:
                        return resp.output["embeddings"][0]["embedding"]
                    chat_logger.error(f"API调用失败: {resp.message} (状态码: {resp.status_code})")
                except Exception as e:
                    chat_logger.error(f"处理数据时出错: {str(e)}")

                if attempt < retry_count - 1:
                    chat_logger.info(f"{retry_delay}秒后重试...")
                    # Non-blocking back-off: time.sleep() would freeze the event loop.
                    await asyncio.sleep(retry_delay)

            return None
        except Exception as e:
            chat_logger.error(f"计算特征向量失败: {str(e)}")
            return None

    async def batch_calculate_vectors(self, image_items: List[Dict]) -> List[Dict]:
        """Compute embedding vectors for a batch of images, one at a time.

        Args:
            image_items: Items each holding ``image`` (a base64 string, with or
                without a ``data:...;base64,`` prefix) and ``image_id``.

        Returns:
            One result dict per input item, in input order, with keys
            ``image_id``, ``vector`` (None on failure), ``success`` and, on
            failure, ``error``.
        """
        results = []
        for item in image_items:
            try:
                # Strip an optional data-URI prefix before decoding the base64.
                if "," in item["image"]:
                    base64_str = item["image"].split(",", 1)[1]
                else:
                    base64_str = item["image"]
                image_bytes = base64.b64decode(base64_str)

                # Compute the embedding vector
                vector = await self.calculate_vector(image_bytes)

                if vector:
                    results.append({
                        "image_id": item.get("image_id"),
                        "vector": vector,
                        "success": True
                    })
                else:
                    results.append({
                        "image_id": item.get("image_id"),
                        "vector": None,
                        "success": False,
                        "error": "计算特征向量失败"
                    })
            except Exception as e:
                chat_logger.error(f"处理图片失败 (ID: {item.get('image_id')}): {str(e)}")
                results.append({
                    "image_id": item.get("image_id"),
                    "vector": None,
                    "success": False,
                    "error": str(e)
                })

        return results

    async def build_index(self, image_vectors: List[Dict]) -> int:
        """Build the FAISS index and the image_id -> image info mapping.

        The previously built index is replaced only when the new one builds
        successfully, so index, mapping and id list always stay consistent.
        Items without a vector are skipped; if an ``image_id`` occurs more than
        once, its last occurrence wins and it is indexed only once.

        Args:
            image_vectors: Items with ``image_id``, ``vector`` and optionally
                ``image_name`` / ``image_path``.

        Returns:
            Number of vectors in the new index (0 if nothing was indexed or the
            build failed).
        """
        try:
            if not image_vectors:
                chat_logger.warning("没有图片向量需要索引")
                return 0

            # image_id -> (vector, info). A repeated id overwrites its entry, so
            # each id maps to exactly one index row.
            entries = {}
            duplicate_count = 0
            for item in image_vectors:
                vector = item.get("vector")
                if vector is None or len(vector) == 0:
                    continue
                image_id = item["image_id"]
                if image_id in entries:
                    duplicate_count += 1
                entries[image_id] = (vector, {
                    "image_name": item.get("image_name"),
                    "image_path": item.get("image_path")
                })

            if not entries:
                chat_logger.warning("没有有效的向量数据")
                return 0
            if duplicate_count:
                chat_logger.warning(f"发现 {duplicate_count} 个重复的 image_id,已保留最后一次出现的向量")

            ids = list(entries)
            vectors_np = np.array([vec for vec, _ in entries.values()], dtype=np.float32)
            if vectors_np.ndim != 2:
                raise ValueError(f"向量维度不一致或格式错误: shape={vectors_np.shape}")

            dimension = vectors_np.shape[1]
            # Inner product over L2-normalized vectors equals cosine similarity.
            index = faiss.IndexFlatIP(dimension)
            faiss.normalize_L2(vectors_np)
            index.add(vectors_np)

            # Publish the new state only after the build fully succeeded.
            self.index = index
            self.dimension = dimension
            self.image_mapping = {image_id: info for image_id, (_, info) in entries.items()}
            self._index_ids = ids
            self.last_build_time = time.strftime('%Y-%m-%d %H:%M:%S')

            indexed_count = len(ids)
            chat_logger.info(f"成功构建索引,索引了 {indexed_count} 个图片向量")
            return indexed_count

        except Exception as e:
            chat_logger.error(f"构建索引失败: {str(e)}")
            return 0

    async def search(self, image_bytes: Optional[bytes] = None, text: Optional[str] = None,
                     top_k: int = 10) -> List[Dict]:
        """Find the indexed images most similar to an image or text query.

        Args:
            image_bytes: Query image bytes (search by image).
            text: Query text (search by text); used only if no image is given.
            top_k: Maximum number of results; capped at the number of indexed
                images.

        Returns:
            Up to ``top_k`` dicts, best match first, with ``image_id``,
            ``similarity`` (cosine similarity in [-1, 1]), ``image_name`` and
            ``image_path``. Empty list on any failure.
        """
        try:
            if self.index is None:
                chat_logger.warning("索引未构建,请先构建索引")
                return []

            # Compute the query vector
            query_vector = await self.calculate_vector(image_bytes=image_bytes, text=text)
            if not query_vector:
                chat_logger.error("无法获取查询特征向量")
                return []

            query_vector_np = np.array([query_vector], dtype=np.float32)
            if query_vector_np.shape[1] != self.index.d:
                chat_logger.error(
                    f"查询向量维度 {query_vector_np.shape[1]} 与索引维度 {self.index.d} 不一致"
                )
                return []

            # Never request more rows than exist: FAISS pads missing neighbours
            # with label -1, which must not be resolved to an image.
            k = min(top_k, self.index.ntotal)
            if k <= 0:
                return []

            faiss.normalize_L2(query_vector_np)
            distances, indices = self.index.search(query_vector_np, k)

            results = []
            for score, idx in zip(distances[0], indices[0]):
                if idx < 0:
                    # Padding slot: no neighbour at this rank.
                    continue
                image_id = self._index_ids[idx]
                image_info = self.image_mapping.get(image_id, {})
                results.append({
                    "image_id": image_id,
                    # Inner product of unit vectors = cosine similarity, in [-1, 1].
                    "similarity": float(score),
                    "image_name": image_info.get("image_name"),
                    "image_path": image_info.get("image_path")
                })

            chat_logger.info(f"搜索完成,找到 {len(results)} 个相似图片")
            return results

        except Exception as e:
            chat_logger.error(f"搜索失败: {str(e)}")
            return []

    async def get_index_status(self) -> Dict:
        """Report the current index status.

        Returns:
            Dict with ``indexed_count``, an estimated ``index_size`` string (raw
            float32 vector storage only), ``last_build_time`` and ``dimension``.
        """
        if self.index is None:
            return {
                "indexed_count": 0,
                "index_size": "0KB",
                "last_build_time": None,
                "dimension": self.dimension
            }

        indexed_count = self.index.ntotal
        # Estimate the index size: each float32 component takes 4 bytes.
        index_size = f"{indexed_count * self.dimension * 4 / 1024:.2f}KB"

        return {
            "indexed_count": indexed_count,
            "index_size": index_size,
            "last_build_time": self.last_build_time,
            "dimension": self.dimension
        }

    async def clear_index(self):
        """Drop the index, its id list and the image mapping."""
        self.index = None
        self.image_mapping = {}
        self._index_ids = []
        self.last_build_time = None
        chat_logger.info("索引已清空")


# Shared service instance
image_search_service = ImageSearchService()