#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Integration tests for the image search HTTP API.

Test cases (run in order by ImageSearchAPITest.run_all_tests):
1. Batch-generate vectors; assert the vectors of image 1 and image 2 differ.
2. Generate vectors again; assert each matches the first run
   (cosine similarity > 0.99).
3. Build the index, check the index status, then build again with the second
   set of vectors and check the index was replaced rather than appended to.
4. Search by image; assert the top hit for each query image is the image itself.
5. Search by the text "胡桃木方" and print the top hits for manual inspection
   (image 3 is the expected top hit; a mismatch is reported, not asserted).
6. Clear the index; assert the index status reports zero entries.

The tests talk to a running server at API_BASE_URL and are not started
automatically; call ImageSearchAPITest().run_all_tests().
"""
import os
import base64
import json
import requests
import numpy as np

# Configuration. Environment variables override the defaults.
API_BASE_URL = os.environ.get("IMAGE_SEARCH_API_URL", "http://localhost:8888")
IMAGE_DIR = os.environ.get(
    "IMAGE_SEARCH_IMAGE_DIR",
    "d:\\天翼云盘\\VS心流\\AI图片识别搜图修图理论与实践\\GME-Qwen2-VL-2B\\product_images_full",
)
MAX_IMAGES = 10
# Per-request timeout in seconds. requests.Session ignores a `timeout`
# attribute, so the timeout has to be passed to every request explicitly.
REQUEST_TIMEOUT = 60

# File extensions (lower-case, with leading dot) treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

# API endpoints, relative to API_BASE_URL.
ENDPOINTS = {
    "batch_vector": "/image/vector/batch",
    "build_index": "/image/index/build",
    "index_status": "/image/index/status",
    "search": "/image/search",
    "clear_index": "/image/index/clear",
}


class ImageSearchAPITest:
    """Manual integration test suite for the image search API.

    The tests share state: test 2 compares against the vectors stored by
    test 1, and test 3 indexes the vectors produced by tests 1 and 2, so they
    must run in order (run_all_tests does this).
    """

    def __init__(self):
        """Initialise shared test state and the HTTP session."""
        self.image_files = []
        self.first_vectors = {}   # image_id -> vector from the first batch call
        self.second_vectors = {}  # image_id -> vector from the second batch call
        self.session = requests.Session()
        self.timeout = REQUEST_TIMEOUT

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    def _post(self, endpoint_key, payload=None):
        """POST `payload` as JSON to ENDPOINTS[endpoint_key] and return the response."""
        url = f"{API_BASE_URL}{ENDPOINTS[endpoint_key]}"
        return self.session.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=self.timeout,
        )

    def _get(self, endpoint_key):
        """GET ENDPOINTS[endpoint_key] and return the response."""
        url = f"{API_BASE_URL}{ENDPOINTS[endpoint_key]}"
        return self.session.get(url, timeout=self.timeout)

    def _fetch_index_status(self):
        """Query the index-status API and return its `status` dict.

        Raises:
            AssertionError: if the call fails or the response has no status.
        """
        response = self._get("index_status")
        assert response.status_code == 200, f"获取索引状态失败,状态码: {response.status_code}"
        status_result = response.json()
        assert status_result.get('success') is True, "获取索引状态失败"
        status = status_result.get('status')
        assert status is not None, "索引状态不存在"
        return status

    # ------------------------------------------------------------------
    # Input helpers
    # ------------------------------------------------------------------

    def get_image_files(self, max_count=10):
        """Collect up to `max_count` image paths from IMAGE_DIR.

        File names are sorted so that image IDs ("1", "2", ...) map to the
        same files on every run; os.listdir order is otherwise arbitrary.
        The result is stored in self.image_files and also returned.
        """
        image_files = [
            os.path.join(IMAGE_DIR, name)
            for name in sorted(os.listdir(IMAGE_DIR))
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        ][:max_count]
        self.image_files = image_files

        # Print the file list for reference.
        print("获取到的图片文件列表:")
        for i, img_file in enumerate(image_files):
            print(f"{i+1}. {os.path.basename(img_file)}")
            print(f" 路径: {img_file}")
        return image_files

    def _ensure_image_files(self):
        """Load the image list if it has not been loaded yet; fail if it is empty."""
        if not self.image_files:
            self.get_image_files(MAX_IMAGES)
        assert self.image_files, f"目录中没有找到图片: {IMAGE_DIR}"

    def image_to_base64(self, image_path):
        """Return the file at `image_path` as a base64 data URL.

        The MIME subtype is derived from the file extension, with ".jpg"
        mapped to the registered type "jpeg". Returns None (after printing
        the error) if the file cannot be read.
        """
        try:
            with open(image_path, 'rb') as f:
                image_bytes = f.read()
        except OSError as e:
            print(f"处理图片失败 {image_path}: {e}")
            return None
        base64_str = base64.b64encode(image_bytes).decode('utf-8')
        ext = os.path.splitext(image_path)[1].lower().replace('.', '')
        # "image/jpg" is not a registered MIME type; JPEG files are image/jpeg.
        subtype = "jpeg" if ext == "jpg" else ext
        return f"data:image/{subtype};base64,{base64_str}"

    def _build_vector_requests(self):
        """Build the batch-vector request body.

        Each image's ID is its 1-based position in self.image_files. Images
        that fail to load are skipped; the remaining IDs are unchanged.
        """
        requests_list = []
        for i, img_file in enumerate(self.image_files):
            base64_data = self.image_to_base64(img_file)
            if base64_data:
                requests_list.append({"image": base64_data, "image_id": str(i + 1)})
        return requests_list

    def _request_vectors(self):
        """Call the batch-vector API for all images and return {image_id: vector}.

        Raises:
            AssertionError: on HTTP failure, a response count mismatch, or any
                item that reports failure or lacks a vector.
        """
        requests_list = self._build_vector_requests()
        response = self._post("batch_vector", requests_list)
        assert response.status_code == 200, f"请求失败,状态码: {response.status_code}"

        result = response.json()
        assert len(result) == len(requests_list), f"响应数量不匹配: {len(result)} != {len(requests_list)}"

        vectors = {}
        for i, item in enumerate(result):
            assert item.get('success') is True, f"图片 {i+1} 向量生成失败: {item.get('error')}"
            image_id = item.get('image_id')
            vector = item.get('vector')
            assert vector is not None, f"图片 {image_id} 没有返回向量"
            vectors[image_id] = vector
        return vectors

    @staticmethod
    def _cosine_similarity(a, b):
        """Return the cosine similarity of two vectors (0.0 if either has zero norm)."""
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def _index_items(self, vectors):
        """Build the build-index payload entries for every image that has a vector in `vectors`."""
        items = []
        for i, img_file in enumerate(self.image_files):
            image_id = str(i + 1)
            if image_id in vectors:
                items.append({
                    "image_id": image_id,
                    "vector": vectors[image_id],
                    "image_name": os.path.basename(img_file),
                    "image_path": img_file,
                })
        return items

    def _build_index(self, image_vectors, prefix=""):
        """POST `image_vectors` to the build-index API and return the reported indexed_count.

        `prefix` is prepended to failure messages (e.g. "第二次" for the
        second build).
        """
        response = self._post("build_index", {"image_vectors": image_vectors})
        assert response.status_code == 200, f"{prefix}构建索引失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"{prefix}构建索引失败: {result.get('error')}"
        indexed_count = result.get('indexed_count')
        assert indexed_count == len(image_vectors), f"索引数量不匹配: {indexed_count} != {len(image_vectors)}"
        return indexed_count

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_batch_vector_generation(self):
        """Test 1: batch-generate vectors and assert image 1 and image 2 differ."""
        print("=== 测试1: 批量生成向量 ===")
        self._ensure_image_files()

        self.first_vectors = self._request_vectors()

        assert '1' in self.first_vectors, "图1向量不存在"
        assert '2' in self.first_vectors, "图2向量不存在"
        assert self.first_vectors['1'] != self.first_vectors['2'], "图1与图2的向量相同,不符合预期"

        print(f"✓ 成功生成 {len(self.first_vectors)} 个向量")
        print("✓ 图1与图2的向量不同")

    def test_vector_consistency(self):
        """Test 2: regenerate vectors and assert each matches test 1's (cosine > 0.99)."""
        print("\n=== 测试2: 向量一致性测试 ===")
        self._ensure_image_files()
        # Without test 1's vectors the comparison loop below would pass vacuously.
        assert self.first_vectors, "第一次生成的向量不存在,请先运行测试1"

        self.second_vectors = self._request_vectors()

        for image_id, first_vector in self.first_vectors.items():
            assert image_id in self.second_vectors, f"图片 {image_id} 第二次向量不存在"
            similarity = self._cosine_similarity(first_vector, self.second_vectors[image_id])
            # Allow tiny numerical drift between runs.
            assert similarity > 0.99, f"图片 {image_id} 向量不一致,相似度: {similarity:.4f}"

        print(f"✓ 成功验证 {len(self.second_vectors)} 个向量的一致性")
        print("✓ 所有图片的向量与第一次生成的向量一致")

    def test_build_index(self):
        """Test 3: build the index twice and check the second build replaces the first."""
        print("\n=== 测试3: 构建索引测试 ===")
        self._ensure_image_files()

        # First build, from the vectors of test 1.
        image_vectors_first = self._index_items(self.first_vectors)
        indexed_count_first = self._build_index(image_vectors_first)

        status = self._fetch_index_status()
        assert status.get('indexed_count') == indexed_count_first, "索引状态数量不匹配"

        print(f"✓ 成功构建索引,索引了 {indexed_count_first} 个图片")
        print(f"✓ 索引状态正确: {status}")

        # Second build, from the vectors of test 2. This should replace the
        # index, not append to it.
        image_vectors_second = self._index_items(self.second_vectors)
        indexed_count_second = self._build_index(image_vectors_second, prefix="第二次")

        status_second = self._fetch_index_status()
        # Check the server-side count: after a replacing build the index holds
        # only the second batch, not first + second.
        status_count = status_second.get('indexed_count')
        assert status_count == len(image_vectors_second), (
            f"索引不应该是累加构建: {status_count} != {len(image_vectors_second)}"
        )

        print(f"✓ 成功进行替代构建,索引了 {indexed_count_second} 个图片")
        print(f"✓ 索引状态正确: {status_second}")

    def test_image_search(self):
        """Test 4: search by image and assert the top hit is the query image itself."""
        print("\n=== 测试4: 以图搜图测试 ===")
        self._ensure_image_files()

        # Query with (at most) the first 3 images.
        for i, img_file in enumerate(self.image_files[:3]):
            image_id = str(i + 1)
            base64_data = self.image_to_base64(img_file)
            assert base64_data is not None, f"图片 {image_id} 转换失败"

            response = self._post("search", {"image": base64_data, "top_k": 3})
            assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"

            result = response.json()
            assert result.get('success') is True, f"搜索失败: {result.get('error')}"
            results = result.get('results') or []
            assert len(results) > 0, "搜索结果为空"

            top_result = results[0]
            assert top_result.get('image_id') == image_id, f"图片 {image_id} 搜索结果第一个不是本图,而是 {top_result.get('image_id')}"
            print(f"✓ 图片 {image_id} 搜索成功,第一个结果是本图")

    def test_text_search(self, expected_top_id="3"):
        """Test 5: search by the text "胡桃木方" and print the top hits.

        The results are printed for manual inspection rather than asserted.
        If the top hit is not `expected_top_id`, a warning is printed. An
        empty result (e.g. the server lacks the text-embedding configuration)
        skips the test with a warning.
        """
        print("\n=== 测试5: 以文搜图测试 ===")

        response = self._post("search", {"text": "胡桃木方", "top_k": 3})
        assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"搜索失败: {result.get('error')}"
        results = result.get('results') or []

        if not results:
            print("⚠️ 搜索结果为空,可能是因为没有配置 DASHSCOPE_API_KEY 或文本与图片不匹配")
            print("⚠️ 跳过文本搜索测试")
            return

        print("✓ 文本搜索'胡桃木方'成功")
        print("搜索结果(前3个):")
        for i, item in enumerate(results[:3]):
            similarity = item.get('similarity')
            # Format defensively: a missing/non-numeric similarity must not crash the report.
            if isinstance(similarity, (int, float)):
                similarity_text = f"{similarity:.4f}"
            else:
                similarity_text = str(similarity)
            print(f" {i+1}. 图片ID: {item.get('image_id')}, 相似度: {similarity_text}, 图片名称: {item.get('image_name')}")

        top_id = results[0].get('image_id')
        if top_id != expected_top_id:
            print(f"⚠️ 第一个结果是 {top_id},不是预期的 {expected_top_id} 号图,请人工检查")

    def test_clear_index(self):
        """Test 6: clear the index and assert the index status reports zero entries."""
        print("\n=== 测试6: 清除索引测试 ===")

        response = self._post("clear_index")
        assert response.status_code == 200, f"清除索引失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"清除索引失败: {result.get('error')}"

        status = self._fetch_index_status()
        assert status.get('indexed_count') == 0, f"索引未被清除,索引数量: {status.get('indexed_count')}"

        print("✓ 成功清除索引")
        print(f"✓ 索引状态正确: {status}")

    def run_all_tests(self):
        """Run all tests in order; print the failure and re-raise on the first error."""
        print("开始图片搜索API测试...")
        print(f"API基础地址: {API_BASE_URL}")
        print(f"测试图片目录: {IMAGE_DIR}")
        print(f"测试图片数量: {MAX_IMAGES}")
        print("=" * 60)

        try:
            self.test_batch_vector_generation()
            self.test_vector_consistency()
            self.test_build_index()
            self.test_image_search()
            self.test_text_search()
            self.test_clear_index()

            print("\n" + "=" * 60)
            print("🎉 所有测试通过!")
            print("=" * 60)
        except Exception as e:
            print(f"\n❌ 测试失败: {e}")
            raise


if __name__ == "__main__":
    # Deliberately does not run the tests; it only prints usage instructions.
    print("测试脚本已加载,等待运行指令...")
    print("请使用以下命令运行测试:")
    print("python tests/test_image_search_api.py")
    print("\n或者在Python交互式环境中运行:")
    print("from tests.test_image_search_api import ImageSearchAPITest")
    print("test = ImageSearchAPITest()")
    print("test.run_all_tests()")
    print("\n注意:本脚本不会自动运行测试,需要手动调用 run_all_tests() 方法")