| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 图片搜索服务
- 功能:批量计算图片特征向量、构建索引、搜索相似图片
- """
- import os
- import json
- import numpy as np
- import faiss
- import base64
- import time
- from typing import List, Dict, Optional, Tuple
- from utils.logger import chat_logger
- # 导入dashscope库,必须安装
- try:
- import dashscope
- from dashscope import MultiModalEmbedding
- from http import HTTPStatus
- DASHSCOPE_AVAILABLE = True
- except ImportError:
- chat_logger.error("dashscope库未安装,请使用 pip install dashscope 安装")
- raise ImportError("dashscope库未安装,请使用 pip install dashscope 安装")
class ImageSearchService:
    """Image similarity search built on DashScope embeddings and a FAISS index.

    The service keeps a flat inner-product FAISS index of L2-normalized
    vectors plus ``image_mapping``, a dict from image ID to display metadata.
    The insertion order of ``image_mapping`` matches the row order of the
    index, which is how ``search`` maps FAISS row numbers back to image IDs.
    """

    def __init__(self):
        """Create an empty service with the default configuration."""
        self.index = None  # FAISS index; None until build_index() succeeds
        # image_id -> {"image_name", "image_path"}; key order == index row order
        self.image_mapping = {}
        self.last_build_time = None  # timestamp string of the last successful build
        # Default embedding dimension (the original author notes it matches
        # tongyi-embedding-vision-plus); replaced by the real dimension of
        # the indexed vectors in build_index().
        self.dimension = 768

        # Tunable parameters
        self.config = {
            "model_name": "tongyi-embedding-vision-plus",
            "retry_count": 3,
            "batch_size": 10,
        }

        chat_logger.info("图片搜索服务初始化完成")

    async def calculate_vector(self, image_bytes: Optional[bytes] = None, text: Optional[str] = None) -> Optional[List[float]]:
        """Compute the embedding of an image or of a text query.

        ``image_bytes`` takes precedence when both arguments are given.

        Args:
            image_bytes: Raw image bytes (optional).
            text: Query text (optional).

        Returns:
            The embedding as returned by the model, or None when no input was
            supplied, the API key is missing, or every attempt failed.
        """
        # Imported locally because the file's import block does not provide it.
        import asyncio

        try:
            # An API key must be configured either on the module or in the env
            if not dashscope.api_key and 'DASHSCOPE_API_KEY' not in os.environ:
                chat_logger.error("请配置 DASHSCOPE_API_KEY 环境变量")
                return None

            if image_bytes:
                # Image query: the bytes are sent as a base64 data URI.
                # NOTE(review): the MIME type is always declared as JPEG,
                # whatever the real image format is - confirm this is fine.
                encoded = base64.b64encode(image_bytes).decode('utf-8')
                input_data = [{"image": f"data:image/jpeg;base64,{encoded}"}]
            elif text:
                # Text query
                input_data = [{"text": text}]
            else:
                chat_logger.error("请提供图片或文本数据")
                return None

            retry_count = self.config["retry_count"]
            retry_delay = self.config.get("retry_delay", 1)  # seconds between attempts

            for attempt in range(retry_count):
                try:
                    # NOTE(review): this call is synchronous and still blocks
                    # the event loop while the request is in flight.
                    resp = dashscope.MultiModalEmbedding.call(
                        model=self.config["model_name"],
                        input=input_data
                    )
                    if resp.status_code == HTTPStatus.OK:
                        return resp.output["embeddings"][0]["embedding"]
                    chat_logger.error(f"API调用失败: {resp.message} (状态码: {resp.status_code})")
                except Exception as e:
                    chat_logger.error(f"处理数据时出错: {str(e)}")

                remaining = retry_count - attempt - 1
                if remaining <= 0:
                    return None
                chat_logger.info(f"{retry_delay}秒后重试(剩余重试次数: {remaining})...")
                # Non-blocking wait: time.sleep() here would freeze every
                # other coroutine running on the same event loop.
                await asyncio.sleep(retry_delay)

            return None

        except Exception as e:
            chat_logger.error(f"计算特征向量失败: {str(e)}")
            return None

    async def batch_calculate_vectors(self, image_items: List[Dict]) -> List[Dict]:
        """Compute embeddings for a list of base64-encoded images.

        Items are processed one after another; a failure on one item is
        recorded in its result entry and does not stop the batch.

        Args:
            image_items: Items holding ``image`` (base64 string, optionally a
                data URI with a ``...,`` prefix) and ``image_id``.

        Returns:
            One dict per input item with ``image_id``, ``vector`` and
            ``success``; failed items also carry an ``error`` message.
        """
        results = []

        for item in image_items:
            try:
                encoded = item["image"]
                # Drop a "data:...;base64," style prefix when present
                if "," in encoded:
                    encoded = encoded.split(",", 1)[1]

                vector = await self.calculate_vector(base64.b64decode(encoded))

                if vector:
                    results.append({
                        "image_id": item.get("image_id"),
                        "vector": vector,
                        "success": True
                    })
                else:
                    results.append({
                        "image_id": item.get("image_id"),
                        "vector": None,
                        "success": False,
                        "error": "计算特征向量失败"
                    })

            except Exception as e:
                chat_logger.error(f"处理图片失败 (ID: {item.get('image_id')}): {str(e)}")
                results.append({
                    "image_id": item.get("image_id"),
                    "vector": None,
                    "success": False,
                    "error": str(e)
                })

        return results

    async def build_index(self, image_vectors: List[Dict]) -> int:
        """Rebuild the FAISS index and the image mapping from scratch.

        The new index and mapping are assembled in local variables and only
        assigned to the instance once everything succeeded, so a failed or
        empty build leaves the previous index and mapping untouched and
        consistent with each other.

        Args:
            image_vectors: Items holding ``image_id`` and ``vector``, plus
                optional ``image_name`` and ``image_path``. Items without a
                vector are skipped. When an ``image_id`` occurs more than
                once, the last occurrence wins.

        Returns:
            Number of vectors in the new index, or 0 when nothing was indexed
            or the build failed.
        """
        try:
            if not image_vectors:
                chat_logger.warning("没有图片向量需要索引")
                return 0

            # image_id -> (vector, info). Keying by ID keeps exactly one index
            # row per mapping key, which search() relies on.
            entries = {}
            for item in image_vectors:
                vector = item.get("vector")
                if not vector:
                    continue
                image_id = item["image_id"]
                if image_id in entries:
                    chat_logger.warning(f"重复的图片ID,已使用最新数据覆盖: {image_id}")
                entries[image_id] = (vector, {
                    "image_name": item.get("image_name"),
                    "image_path": item.get("image_path")
                })

            if not entries:
                chat_logger.warning("没有有效的向量数据")
                return 0

            vectors_np = np.array([vec for vec, _ in entries.values()], dtype=np.float32)
            dimension = vectors_np.shape[1]

            # Inner product over L2-normalized vectors equals cosine similarity
            new_index = faiss.IndexFlatIP(dimension)
            faiss.normalize_L2(vectors_np)
            new_index.add(vectors_np)

            # Commit only after every step above succeeded
            self.index = new_index
            self.image_mapping = {image_id: info for image_id, (_, info) in entries.items()}
            self.dimension = dimension
            self.last_build_time = time.strftime('%Y-%m-%d %H:%M:%S')

            indexed_count = len(entries)
            chat_logger.info(f"成功构建索引,索引了 {indexed_count} 个图片向量")

            return indexed_count

        except Exception as e:
            chat_logger.error(f"构建索引失败: {str(e)}")
            return 0

    async def search(self, image_bytes: Optional[bytes] = None, text: Optional[str] = None,
                     top_k: int = 10) -> List[Dict]:
        """Find the indexed images most similar to an image or text query.

        Args:
            image_bytes: Raw image bytes (search by image).
            text: Text description (search by text).
            top_k: Maximum number of results to return.

        Returns:
            Result dicts with ``image_id``, ``similarity``, ``image_name`` and
            ``image_path``, in the order FAISS returns them. Empty when the
            index is missing, ``top_k`` is not positive, the query could not
            be embedded, or the search failed.
        """
        try:
            if self.index is None:
                chat_logger.warning("索引未构建,请先构建索引")
                return []

            # Reject a useless request before spending an embedding API call
            if top_k <= 0:
                chat_logger.warning(f"top_k 必须为正整数: {top_k}")
                return []

            query_vector = await self.calculate_vector(image_bytes=image_bytes, text=text)
            if not query_vector:
                chat_logger.error("无法获取查询特征向量")
                return []

            # Normalize the query the same way the indexed vectors were
            query_np = np.array([query_vector], dtype=np.float32)
            faiss.normalize_L2(query_np)

            results = []
            # Never ask for more neighbours than the index holds: FAISS pads
            # the missing slots with label -1, and -1 used as a Python list
            # index would silently select the last image.
            k = min(top_k, self.index.ntotal)
            if k > 0:
                distances, indices = self.index.search(query_np, k)

                # Row number -> image ID, built once instead of once per hit
                image_ids = list(self.image_mapping)

                for distance, idx in zip(distances[0], indices[0]):
                    row = int(idx)
                    if row < 0 or row >= len(image_ids):
                        continue  # padding label or a row with no mapping entry
                    image_id = image_ids[row]
                    image_info = self.image_mapping[image_id]
                    results.append({
                        "image_id": image_id,
                        # Raw inner product of normalized vectors (cosine
                        # similarity); it is not rescaled and no threshold
                        # is applied.
                        "similarity": float(distance),
                        "image_name": image_info.get("image_name"),
                        "image_path": image_info.get("image_path")
                    })

            chat_logger.info(f"搜索完成,找到 {len(results)} 个相似图片")
            return results

        except Exception as e:
            chat_logger.error(f"搜索失败: {str(e)}")
            return []

    async def get_index_status(self) -> Dict:
        """Report the size and age of the current index.

        Returns:
            Dict with ``indexed_count``, ``index_size`` (estimated, formatted
            in KB), ``last_build_time`` and ``dimension``.
        """
        if self.index is None:
            return {
                "indexed_count": 0,
                "index_size": "0KB",
                "last_build_time": None,
                "dimension": self.dimension
            }

        indexed_count = self.index.ntotal
        # Estimate: one float32 (4 bytes) per component of every vector
        size_kb = indexed_count * self.dimension * 4 / 1024

        return {
            "indexed_count": indexed_count,
            "index_size": f"{size_kb:.2f}KB",
            "last_build_time": self.last_build_time,
            "dimension": self.dimension
        }

    async def clear_index(self):
        """Drop the index, the image mapping and the build timestamp."""
        try:
            self.index = None
            self.image_mapping = {}
            self.last_build_time = None
            chat_logger.info("索引已清空")
        except Exception as e:
            chat_logger.error(f"清空索引失败: {str(e)}")
- # Create the module-level service instance
- image_search_service = ImageSearchService()
|