| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 测试图片搜索API
- 包含以下测试用例:
- 1. 批量生成向量,断言图1与图2的向量不同
- 2. 再次生成向量,断言与第一步生成的向量一致
- 3. 构建索引API和获取索引状态
- 4. 以图搜图API,断言第一个结果是本图
- 5. 以文本搜图,搜索"胡桃木方",断言第一个结果是3号图
- """
- import os
- import base64
- import json
- import requests
- import numpy as np
# Configuration. Every value can be overridden through an environment variable
# so the script can run on machines other than the original author's; the
# defaults are the values this script has always used.
API_BASE_URL = os.environ.get("IMAGE_SEARCH_API_BASE_URL", "http://localhost:8888")
IMAGE_DIR = os.environ.get(
    "IMAGE_SEARCH_IMAGE_DIR",
    "d:\\天翼云盘\\VS心流\\AI图片识别搜图修图理论与实践\\GME-Qwen2-VL-2B\\product_images_full",
)
# Upper bound on how many images from IMAGE_DIR are used in the tests.
MAX_IMAGES = int(os.environ.get("IMAGE_SEARCH_MAX_IMAGES", "10"))

# API endpoint paths, appended to API_BASE_URL.
ENDPOINTS = {
    "batch_vector": "/image/vector/batch",
    "build_index": "/image/index/build",
    "index_status": "/image/index/status",
    "search": "/image/search",
    "clear_index": "/image/index/clear",
}
class ImageSearchAPITest:
    """End-to-end tests for the image search HTTP API.

    The ``test_*`` methods are designed to run in the order used by
    ``run_all_tests``: later steps reuse state gathered by earlier ones
    (the image list, both vector sets and the built index).
    """

    def __init__(self, timeout=60):
        """Create the harness and its HTTP session.

        Args:
            timeout: Timeout in seconds applied to every HTTP request.
        """
        self.image_files = []
        self.first_vectors = {}   # image_id -> vector from the first batch call
        self.second_vectors = {}  # image_id -> vector from the second batch call
        self.session = requests.Session()
        # requests.Session has no session-wide timeout: assigning
        # ``session.timeout`` is silently ignored, so the value is kept here
        # and passed explicitly to every call in _post/_get.
        self.timeout = timeout

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------
    def _post(self, path, payload=None):
        """POST ``payload`` as JSON to ``API_BASE_URL + path`` and return the response.

        With ``payload=None`` no request body is sent.
        """
        return self.session.post(
            f"{API_BASE_URL}{path}",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=self.timeout,
        )

    def _get(self, path):
        """GET ``API_BASE_URL + path`` and return the response."""
        return self.session.get(f"{API_BASE_URL}{path}", timeout=self.timeout)

    def _fetch_index_status(self):
        """Query the index-status endpoint and return its ``status`` dict.

        Raises:
            AssertionError: if the HTTP call or the API reports failure, or no
                status object is returned.
        """
        response = self._get(ENDPOINTS['index_status'])
        assert response.status_code == 200, f"获取索引状态失败,状态码: {response.status_code}"
        status_result = response.json()
        assert status_result.get('success') is True, "获取索引状态失败"
        status = status_result.get('status')
        assert status is not None, "索引状态不存在"
        return status

    # ------------------------------------------------------------------
    # Input preparation
    # ------------------------------------------------------------------
    def get_image_files(self, max_count=10, image_dir=None):
        """Collect up to ``max_count`` image files and store them in ``self.image_files``.

        Args:
            max_count: Maximum number of images to return.
            image_dir: Directory to scan; defaults to the module-level IMAGE_DIR.

        Returns:
            List of full paths. Image IDs used elsewhere ("1", "2", ...) are
            the 1-based positions in this list.
        """
        directory = IMAGE_DIR if image_dir is None else image_dir
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
        image_files = []

        # os.listdir order is arbitrary; sorting keeps the image-ID -> file
        # mapping identical across runs and platforms.
        for file_name in sorted(os.listdir(directory)):
            if len(image_files) >= max_count:
                break
            path = os.path.join(directory, file_name)
            if os.path.splitext(file_name)[1].lower() in image_extensions and os.path.isfile(path):
                image_files.append(path)

        self.image_files = image_files

        # Print the list so results can be matched to files by hand.
        print("获取到的图片文件列表:")
        for position, img_file in enumerate(image_files, start=1):
            print(f"{position}. {os.path.basename(img_file)}")
            print(f" 路径: {img_file}")

        return image_files

    def image_to_base64(self, image_path):
        """Read an image file and return it as a ``data:image/<type>;base64,...`` URL.

        Returns:
            The data URL, or None (after printing the error) if the file
            cannot be read.
        """
        try:
            with open(image_path, 'rb') as f:
                image_bytes = f.read()
        except OSError as e:
            print(f"处理图片失败 {image_path}: {e}")
            return None

        base64_str = base64.b64encode(image_bytes).decode('ascii')
        ext = os.path.splitext(image_path)[1].lower().lstrip('.')
        # "image/jpg" is not a registered media type; JPEG files are "image/jpeg".
        subtype = 'jpeg' if ext == 'jpg' else ext
        return f"data:image/{subtype};base64,{base64_str}"

    def _build_vector_requests(self):
        """Build batch-vector request items for every image in ``self.image_files``.

        Images that fail to encode are skipped, but the remaining items keep
        their 1-based position as image_id so IDs still match ``image_files``.
        """
        requests_list = []
        for position, img_file in enumerate(self.image_files, start=1):
            base64_data = self.image_to_base64(img_file)
            if base64_data:
                requests_list.append({"image": base64_data, "image_id": str(position)})
        return requests_list

    def _generate_vectors(self, target):
        """Request vectors for all images and store them in ``target`` keyed by image_id.

        Loads the image list first if it has not been loaded yet.

        Raises:
            AssertionError: if no image could be encoded, the request fails,
                or any item in the response reports failure.
        """
        if not self.image_files:
            self.get_image_files(MAX_IMAGES)

        requests_list = self._build_vector_requests()
        assert requests_list, "没有可用的测试图片"

        response = self._post(ENDPOINTS['batch_vector'], requests_list)
        assert response.status_code == 200, f"请求失败,状态码: {response.status_code}"

        result = response.json()
        assert len(result) == len(requests_list), f"响应数量不匹配: {len(result)} != {len(requests_list)}"

        for item in result:
            image_id = item.get('image_id')
            assert item.get('success') is True, f"图片 {image_id} 向量生成失败: {item.get('error')}"
            vector = item.get('vector')
            assert vector is not None, f"图片 {image_id} 没有返回向量"
            target[image_id] = vector

    def _build_index(self, vectors, label=""):
        """Build the index from ``vectors`` and check the count reported by the API.

        Args:
            vectors: Mapping image_id -> vector; only IDs matching a position
                in ``self.image_files`` are sent.
            label: Prefix for failure messages (e.g. "第二次").

        Returns:
            Tuple ``(image_vectors, indexed_count)``: the payload items sent
            and the count returned by the build call.
        """
        image_vectors = []
        for position, img_file in enumerate(self.image_files, start=1):
            image_id = str(position)
            if image_id in vectors:
                image_vectors.append({
                    "image_id": image_id,
                    "vector": vectors[image_id],
                    "image_name": os.path.basename(img_file),
                    "image_path": img_file,
                })

        response = self._post(ENDPOINTS['build_index'], {"image_vectors": image_vectors})
        assert response.status_code == 200, f"{label}构建索引失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"{label}构建索引失败: {result.get('error')}"
        indexed_count = result.get('indexed_count')
        assert indexed_count == len(image_vectors), f"索引数量不匹配: {indexed_count} != {len(image_vectors)}"
        return image_vectors, indexed_count

    # ------------------------------------------------------------------
    # Test cases
    # ------------------------------------------------------------------
    def test_batch_vector_generation(self):
        """Test 1: batch-generate vectors and assert that images 1 and 2 differ."""
        print("=== 测试1: 批量生成向量 ===")

        self._generate_vectors(self.first_vectors)

        assert '1' in self.first_vectors, "图1向量不存在"
        assert '2' in self.first_vectors, "图2向量不存在"
        assert self.first_vectors['1'] != self.first_vectors['2'], "图1与图2的向量相同,不符合预期"

        print(f"✓ 成功生成 {len(self.first_vectors)} 个向量")
        print("✓ 图1与图2的向量不同")

    def test_vector_consistency(self):
        """Test 2: regenerate vectors and assert each matches test 1 (cosine > 0.99)."""
        print("\n=== 测试2: 向量一致性测试 ===")

        self._generate_vectors(self.second_vectors)

        for image_id, first_vector in self.first_vectors.items():
            assert image_id in self.second_vectors, f"图片 {image_id} 第二次向量不存在"

            first = np.asarray(first_vector, dtype=float)
            second = np.asarray(self.second_vectors[image_id], dtype=float)
            norm_product = np.linalg.norm(first) * np.linalg.norm(second)
            # A zero vector would make cosine similarity NaN; report it explicitly.
            assert norm_product > 0, f"图片 {image_id} 向量为零向量,无法计算相似度"
            similarity = float(np.dot(first, second) / norm_product)

            # Small numeric drift between runs is tolerated.
            assert similarity > 0.99, f"图片 {image_id} 向量不一致,相似度: {similarity:.4f}"

        print(f"✓ 成功验证 {len(self.second_vectors)} 个向量的一致性")
        print("✓ 所有图片的向量与第一次生成的向量一致")

    def test_build_index(self):
        """Test 3: build the index, check its status, then verify a rebuild replaces it."""
        print("\n=== 测试3: 构建索引测试 ===")

        _, indexed_count_first = self._build_index(self.first_vectors)
        status = self._fetch_index_status()
        assert status.get('indexed_count') == indexed_count_first, "索引状态数量不匹配"

        print(f"✓ 成功构建索引,索引了 {indexed_count_first} 个图片")
        print(f"✓ 索引状态正确: {status}")

        # Rebuild from the second vector set. A replacing build leaves the
        # index-wide total equal to the second batch only; an accumulating one
        # would report first + second. The status endpoint (not the per-call
        # build count) is what reveals the difference.
        _, indexed_count_second = self._build_index(self.second_vectors, "第二次")
        status_second = self._fetch_index_status()
        total_after_rebuild = status_second.get('indexed_count')
        assert total_after_rebuild == indexed_count_second, (
            f"索引不应该是累加构建: 索引总数 {total_after_rebuild}, 第二次构建数量 {indexed_count_second}"
        )

        print(f"✓ 成功进行替代构建,索引了 {indexed_count_second} 个图片")
        print(f"✓ 索引状态正确: {status_second}")

    def test_image_search(self):
        """Test 4: query by image and assert the top hit is the query image itself."""
        print("\n=== 测试4: 以图搜图测试 ===")

        # Only the first three images are queried to keep the run short.
        for position, img_file in enumerate(self.image_files[:3], start=1):
            image_id = str(position)

            base64_data = self.image_to_base64(img_file)
            assert base64_data is not None, f"图片 {image_id} 转换失败"

            response = self._post(ENDPOINTS['search'], {"image": base64_data, "top_k": 3})
            assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"

            result = response.json()
            assert result.get('success') is True, f"搜索失败: {result.get('error')}"
            results = result.get('results', [])
            assert len(results) > 0, "搜索结果为空"

            top_id = results[0].get('image_id')
            assert top_id == image_id, f"图片 {image_id} 搜索结果第一个不是本图,而是 {top_id}"

            print(f"✓ 图片 {image_id} 搜索成功,第一个结果是本图")

    def test_text_search(self, query="胡桃木方", expected_image_id=None):
        """Test 5: search images by text and print the top results.

        By default the results are only printed for manual inspection, and
        the test is skipped when the search returns nothing (e.g. no
        DASHSCOPE_API_KEY configured). Pass ``expected_image_id`` (e.g. "3")
        to assert that the top result is that image.

        Args:
            query: Text to search for.
            expected_image_id: Optional image_id the first result must have.
        """
        print("\n=== 测试5: 以文搜图测试 ===")

        response = self._post(ENDPOINTS['search'], {"text": query, "top_k": 3})
        assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"搜索失败: {result.get('error')}"
        results = result.get('results', [])

        if not results:
            print("⚠️ 搜索结果为空,可能是因为没有配置 DASHSCOPE_API_KEY 或文本与图片不匹配")
            print("⚠️ 跳过文本搜索测试")
            return

        print(f"✓ 文本搜索'{query}'成功")
        print("搜索结果(前3个):")
        for rank, item in enumerate(results[:3], start=1):
            similarity = item.get('similarity')
            # Guard the format spec: a missing/non-numeric similarity must not crash the report.
            similarity_text = f"{similarity:.4f}" if isinstance(similarity, (int, float)) else str(similarity)
            print(f" {rank}. 图片ID: {item.get('image_id')}, 相似度: {similarity_text}, 图片名称: {item.get('image_name')}")

        if expected_image_id is not None:
            top_id = results[0].get('image_id')
            assert top_id == expected_image_id, f"文本搜索第一个结果应为图片 {expected_image_id},实际为 {top_id}"

    def test_clear_index(self):
        """Test 6: clear the index and assert the status reports zero entries."""
        print("\n=== 测试6: 清除索引测试 ===")

        response = self._post(ENDPOINTS.get('clear_index', '/image/index/clear'))
        assert response.status_code == 200, f"清除索引失败,状态码: {response.status_code}"

        result = response.json()
        assert result.get('success') is True, f"清除索引失败: {result.get('error')}"

        status = self._fetch_index_status()
        assert status.get('indexed_count') == 0, f"索引未被清除,索引数量: {status.get('indexed_count')}"

        print("✓ 成功清除索引")
        print(f"✓ 索引状态正确: {status}")

    def run_all_tests(self):
        """Run tests 1-6 in order; print a summary and re-raise the first failure."""
        print("开始图片搜索API测试...")
        print(f"API基础地址: {API_BASE_URL}")
        print(f"测试图片目录: {IMAGE_DIR}")
        print(f"测试图片数量: {MAX_IMAGES}")
        print("=" * 60)

        try:
            self.test_batch_vector_generation()
            self.test_vector_consistency()
            self.test_build_index()
            self.test_image_search()
            self.test_text_search()
            self.test_clear_index()

            print("\n" + "=" * 60)
            print("🎉 所有测试通过!")
            print("=" * 60)

        except Exception as e:
            print(f"\n❌ 测试失败: {e}")
            raise
if __name__ == "__main__":
    import sys

    # With --run the suite executes; without it the script only prints usage,
    # so launching or importing it never hits the API by accident.
    if "--run" in sys.argv[1:]:
        ImageSearchAPITest().run_all_tests()
    else:
        print("测试脚本已加载,等待运行指令...")
        print("请使用以下命令运行测试:")
        print("python tests/test_image_search_api.py --run")
        print("\n或者在Python交互式环境中运行:")
        print("from tests.test_image_search_api import ImageSearchAPITest")
        print("test = ImageSearchAPITest()")
        print("test.run_all_tests()")
        print("\n注意:不带 --run 参数时本脚本不会自动运行测试,需要手动调用 run_all_tests() 方法")
|