test_image_search_api.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
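Tests for the image search API.

Test cases:
1. Generate vectors in batch and assert that the vectors of image 1 and image 2 differ.
2. Generate the vectors again and assert that they match the vectors from step 1.
3. Build the index through the API and read back the index status.
4. Search by image and assert that the top result is the query image itself.
5. Search by text for "胡桃木方" and print the results for manual inspection
   (image 3 is the expected top hit, but this is not asserted).
6. Clear the index and assert that the index status reports zero entries.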
  11. """
  12. import os
  13. import base64
  14. import json
  15. import requests
  16. import numpy as np
  17. # 配置
  18. API_BASE_URL = "http://localhost:8888"
  19. IMAGE_DIR = "d:\\天翼云盘\\VS心流\\AI图片识别搜图修图理论与实践\\GME-Qwen2-VL-2B\\product_images_full"
  20. MAX_IMAGES = 10
  21. # API端点
  22. ENDPOINTS = {
  23. "batch_vector": "/image/vector/batch",
  24. "build_index": "/image/index/build",
  25. "index_status": "/image/index/status",
  26. "search": "/image/search"
  27. }
  28. class ImageSearchAPITest:
  29. def __init__(self):
  30. """初始化测试类"""
  31. self.image_files = []
  32. self.first_vectors = {} # 第一次生成的向量
  33. self.second_vectors = {} # 第二次生成的向量
  34. self.session = requests.Session()
  35. self.session.timeout = 60
  36. def get_image_files(self, max_count=10):
  37. """获取图片文件列表"""
  38. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
  39. image_files = []
  40. for file in os.listdir(IMAGE_DIR):
  41. if len(image_files) >= max_count:
  42. break
  43. file_ext = os.path.splitext(file)[1].lower()
  44. if file_ext in image_extensions:
  45. image_files.append(os.path.join(IMAGE_DIR, file))
  46. self.image_files = image_files
  47. # 输出文件列表供参考
  48. print("获取到的图片文件列表:")
  49. for i, img_file in enumerate(image_files):
  50. print(f"{i+1}. {os.path.basename(img_file)}")
  51. print(f" 路径: {img_file}")
  52. return image_files
  53. def image_to_base64(self, image_path):
  54. """将图片转换为Base64编码"""
  55. try:
  56. with open(image_path, 'rb') as f:
  57. image_bytes = f.read()
  58. base64_str = base64.b64encode(image_bytes).decode('utf-8')
  59. ext = os.path.splitext(image_path)[1].lower().replace('.', '')
  60. data_url = f"data:image/{ext};base64,{base64_str}"
  61. return data_url
  62. except Exception as e:
  63. print(f"处理图片失败 {image_path}: {e}")
  64. return None
  65. def test_batch_vector_generation(self):
  66. """测试1:批量生成向量,断言图1与图2的向量不同"""
  67. print("=== 测试1: 批量生成向量 ===")
  68. # 获取图片文件
  69. if not self.image_files:
  70. self.get_image_files(MAX_IMAGES)
  71. # 构建请求体
  72. requests_list = []
  73. for i, img_file in enumerate(self.image_files):
  74. base64_data = self.image_to_base64(img_file)
  75. if base64_data:
  76. request_item = {
  77. "image": base64_data,
  78. "image_id": str(i+1)
  79. }
  80. requests_list.append(request_item)
  81. # 发送请求
  82. url = f"{API_BASE_URL}{ENDPOINTS['batch_vector']}"
  83. response = self.session.post(
  84. url,
  85. json=requests_list,
  86. headers={"Content-Type": "application/json"}
  87. )
  88. assert response.status_code == 200, f"请求失败,状态码: {response.status_code}"
  89. # 解析响应
  90. result = response.json()
  91. assert len(result) == len(requests_list), f"响应数量不匹配: {len(result)} != {len(requests_list)}"
  92. # 存储向量
  93. for i, item in enumerate(result):
  94. assert item.get('success') is True, f"图片 {i+1} 向量生成失败: {item.get('error')}"
  95. image_id = item.get('image_id')
  96. vector = item.get('vector')
  97. assert vector is not None, f"图片 {image_id} 没有返回向量"
  98. self.first_vectors[image_id] = vector
  99. # 断言图1与图2的向量不同
  100. assert '1' in self.first_vectors, "图1向量不存在"
  101. assert '2' in self.first_vectors, "图2向量不存在"
  102. vector1 = self.first_vectors['1']
  103. vector2 = self.first_vectors['2']
  104. assert vector1 != vector2, "图1与图2的向量相同,不符合预期"
  105. print(f"✓ 成功生成 {len(self.first_vectors)} 个向量")
  106. print(f"✓ 图1与图2的向量不同")
  107. def test_vector_consistency(self):
  108. """测试2:再次生成向量,断言与第一步生成的向量一致"""
  109. print("\n=== 测试2: 向量一致性测试 ===")
  110. # 构建请求体
  111. requests_list = []
  112. for i, img_file in enumerate(self.image_files):
  113. base64_data = self.image_to_base64(img_file)
  114. if base64_data:
  115. request_item = {
  116. "image": base64_data,
  117. "image_id": str(i+1)
  118. }
  119. requests_list.append(request_item)
  120. # 发送请求
  121. url = f"{API_BASE_URL}{ENDPOINTS['batch_vector']}"
  122. response = self.session.post(
  123. url,
  124. json=requests_list,
  125. headers={"Content-Type": "application/json"}
  126. )
  127. assert response.status_code == 200, f"请求失败,状态码: {response.status_code}"
  128. # 解析响应
  129. result = response.json()
  130. assert len(result) == len(requests_list), f"响应数量不匹配: {len(result)} != {len(requests_list)}"
  131. # 存储向量
  132. for i, item in enumerate(result):
  133. assert item.get('success') is True, f"图片 {i+1} 向量生成失败: {item.get('error')}"
  134. image_id = item.get('image_id')
  135. vector = item.get('vector')
  136. assert vector is not None, f"图片 {image_id} 没有返回向量"
  137. self.second_vectors[image_id] = vector
  138. # 断言向量一致性
  139. for image_id, first_vector in self.first_vectors.items():
  140. assert image_id in self.second_vectors, f"图片 {image_id} 第二次向量不存在"
  141. second_vector = self.second_vectors[image_id]
  142. # 计算向量相似度
  143. similarity = np.dot(first_vector, second_vector) / (
  144. np.linalg.norm(first_vector) * np.linalg.norm(second_vector)
  145. )
  146. # 断言向量接近(相似度大于0.99)
  147. assert similarity > 0.99, f"图片 {image_id} 向量不一致,相似度: {similarity:.4f}"
  148. print(f"✓ 成功验证 {len(self.second_vectors)} 个向量的一致性")
  149. print(f"✓ 所有图片的向量与第一次生成的向量一致")
  150. def test_build_index(self):
  151. """测试3:构建索引API和获取索引状态"""
  152. print("\n=== 测试3: 构建索引测试 ===")
  153. # 第一次构建索引(使用first_vectors)
  154. image_vectors_first = []
  155. for i, img_file in enumerate(self.image_files):
  156. image_id = str(i+1)
  157. if image_id in self.first_vectors:
  158. vector_item = {
  159. "image_id": image_id,
  160. "vector": self.first_vectors[image_id],
  161. "image_name": os.path.basename(img_file),
  162. "image_path": img_file
  163. }
  164. image_vectors_first.append(vector_item)
  165. # 发送第一次构建索引请求
  166. url = f"{API_BASE_URL}{ENDPOINTS['build_index']}"
  167. response = self.session.post(
  168. url,
  169. json={"image_vectors": image_vectors_first},
  170. headers={"Content-Type": "application/json"}
  171. )
  172. assert response.status_code == 200, f"构建索引失败,状态码: {response.status_code}"
  173. # 解析响应
  174. result = response.json()
  175. assert result.get('success') is True, f"构建索引失败: {result.get('error')}"
  176. indexed_count_first = result.get('indexed_count')
  177. assert indexed_count_first == len(image_vectors_first), f"索引数量不匹配: {indexed_count_first} != {len(image_vectors_first)}"
  178. # 获取索引状态
  179. url = f"{API_BASE_URL}{ENDPOINTS['index_status']}"
  180. response = self.session.get(url)
  181. assert response.status_code == 200, f"获取索引状态失败,状态码: {response.status_code}"
  182. # 解析响应
  183. status_result = response.json()
  184. assert status_result.get('success') is True, f"获取索引状态失败"
  185. status = status_result.get('status')
  186. assert status is not None, "索引状态不存在"
  187. assert status.get('indexed_count') == indexed_count_first, f"索引状态数量不匹配"
  188. print(f"✓ 成功构建索引,索引了 {indexed_count_first} 个图片")
  189. print(f"✓ 索引状态正确: {status}")
  190. # 第二次构建索引(使用second_vectors,应该是替代构建不是累加构建)
  191. image_vectors_second = []
  192. for i, img_file in enumerate(self.image_files):
  193. image_id = str(i+1)
  194. if image_id in self.second_vectors:
  195. vector_item = {
  196. "image_id": image_id,
  197. "vector": self.second_vectors[image_id],
  198. "image_name": os.path.basename(img_file),
  199. "image_path": img_file
  200. }
  201. image_vectors_second.append(vector_item)
  202. # 发送第二次构建索引请求(使用正确的build_index URL)
  203. url = f"{API_BASE_URL}{ENDPOINTS['build_index']}"
  204. response = self.session.post(
  205. url,
  206. json={"image_vectors": image_vectors_second},
  207. headers={"Content-Type": "application/json"}
  208. )
  209. assert response.status_code == 200, f"第二次构建索引失败,状态码: {response.status_code}"
  210. # 解析响应
  211. result_second = response.json()
  212. assert result_second.get('success') is True, f"第二次构建索引失败: {result_second.get('error')}"
  213. indexed_count_second = result_second.get('indexed_count')
  214. assert indexed_count_second == len(image_vectors_second), f"索引数量不匹配: {indexed_count_second} != {len(image_vectors_second)}"
  215. # 获取索引状态(使用正确的index_status URL)
  216. url = f"{API_BASE_URL}{ENDPOINTS['index_status']}"
  217. response = self.session.get(url)
  218. assert response.status_code == 200, f"获取索引状态失败,状态码: {response.status_code}"
  219. # 解析响应
  220. status_result_second = response.json()
  221. assert status_result_second.get('success') is True, f"获取索引状态失败"
  222. status_second = status_result_second.get('status')
  223. assert status_second is not None, "索引状态不存在"
  224. assert status_second.get('indexed_count') == indexed_count_second, f"索引状态数量不匹配"
  225. # 断言是替代构建不是累加构建(索引数量应该等于第二次的数量,而不是两次的总和)
  226. assert indexed_count_second == len(image_vectors_second), f"索引数量应该等于第二次构建的数量,而不是累加"
  227. assert indexed_count_second != indexed_count_first + len(image_vectors_second), f"索引不应该是累加构建"
  228. print(f"✓ 成功进行替代构建,索引了 {indexed_count_second} 个图片")
  229. print(f"✓ 索引状态正确: {status_second}")
  230. def test_image_search(self):
  231. """测试4:以图搜图API,断言第一个结果是本图"""
  232. print("\n=== 测试4: 以图搜图测试 ===")
  233. # 测试前3张图片
  234. test_count = min(3, len(self.image_files))
  235. for i in range(test_count):
  236. img_file = self.image_files[i]
  237. image_id = str(i+1)
  238. # 构建搜索请求
  239. base64_data = self.image_to_base64(img_file)
  240. assert base64_data is not None, f"图片 {image_id} 转换失败"
  241. search_request = {
  242. "image": base64_data,
  243. "top_k": 3
  244. }
  245. # 发送搜索请求
  246. url = f"{API_BASE_URL}{ENDPOINTS['search']}"
  247. response = self.session.post(
  248. url,
  249. json=search_request,
  250. headers={"Content-Type": "application/json"}
  251. )
  252. assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"
  253. # 解析响应
  254. result = response.json()
  255. assert result.get('success') is True, f"搜索失败: {result.get('error')}"
  256. results = result.get('results', [])
  257. assert len(results) > 0, "搜索结果为空"
  258. # 断言第一个结果是本图
  259. top_result = results[0]
  260. assert top_result.get('image_id') == image_id, f"图片 {image_id} 搜索结果第一个不是本图,而是 {top_result.get('image_id')}"
  261. print(f"✓ 图片 {image_id} 搜索成功,第一个结果是本图")
  262. def test_text_search(self):
  263. """测试5:以文本搜图,搜索"胡桃木方",断言第一个结果是3号图"""
  264. print("\n=== 测试5: 以文搜图测试 ===")
  265. # 构建搜索请求(降低阈值以提高成功率)
  266. search_request = {
  267. "text": "胡桃木方",
  268. "top_k": 3
  269. }
  270. # 发送搜索请求
  271. url = f"{API_BASE_URL}{ENDPOINTS['search']}"
  272. response = self.session.post(
  273. url,
  274. json=search_request,
  275. headers={"Content-Type": "application/json"}
  276. )
  277. assert response.status_code == 200, f"搜索失败,状态码: {response.status_code}"
  278. # 解析响应
  279. result = response.json()
  280. assert result.get('success') is True, f"搜索失败: {result.get('error')}"
  281. results = result.get('results', [])
  282. if len(results) == 0:
  283. print("⚠️ 搜索结果为空,可能是因为没有配置 DASHSCOPE_API_KEY 或文本与图片不匹配")
  284. print("⚠️ 跳过文本搜索测试")
  285. return
  286. # 输出搜索结果,供人工检查
  287. print(f"✓ 文本搜索'胡桃木方'成功")
  288. print(f"搜索结果(前3个):")
  289. for i, item in enumerate(results[:3]):
  290. print(f" {i+1}. 图片ID: {item.get('image_id')}, 相似度: {item.get('similarity'):.4f}, 图片名称: {item.get('image_name')}")
  291. def test_clear_index(self):
  292. """测试6:清除索引测试"""
  293. print("\n=== 测试6: 清除索引测试 ===")
  294. # 获取清除索引的端点(假设是 /image/index/clear)
  295. clear_index_endpoint = "/image/index/clear"
  296. url = f"{API_BASE_URL}{clear_index_endpoint}"
  297. # 发送清除索引请求
  298. response = self.session.post(
  299. url,
  300. headers={"Content-Type": "application/json"}
  301. )
  302. assert response.status_code == 200, f"清除索引失败,状态码: {response.status_code}"
  303. # 解析响应
  304. result = response.json()
  305. assert result.get('success') is True, f"清除索引失败: {result.get('error')}"
  306. # 获取索引状态,验证索引已被清除
  307. url = f"{API_BASE_URL}{ENDPOINTS['index_status']}"
  308. response = self.session.get(url)
  309. assert response.status_code == 200, f"获取索引状态失败,状态码: {response.status_code}"
  310. # 解析响应
  311. status_result = response.json()
  312. assert status_result.get('success') is True, f"获取索引状态失败"
  313. status = status_result.get('status')
  314. assert status is not None, "索引状态不存在"
  315. assert status.get('indexed_count') == 0, f"索引未被清除,索引数量: {status.get('indexed_count')}"
  316. print(f"✓ 成功清除索引")
  317. print(f"✓ 索引状态正确: {status}")
  318. def run_all_tests(self):
  319. """运行所有测试"""
  320. print("开始图片搜索API测试...")
  321. print(f"API基础地址: {API_BASE_URL}")
  322. print(f"测试图片目录: {IMAGE_DIR}")
  323. print(f"测试图片数量: {MAX_IMAGES}")
  324. print("=" * 60)
  325. try:
  326. self.test_batch_vector_generation()
  327. self.test_vector_consistency()
  328. self.test_build_index()
  329. self.test_image_search()
  330. self.test_text_search()
  331. self.test_clear_index()
  332. print("\n" + "=" * 60)
  333. print("🎉 所有测试通过!")
  334. print("=" * 60)
  335. except Exception as e:
  336. print(f"\n❌ 测试失败: {e}")
  337. raise
  338. if __name__ == "__main__":
  339. print("测试脚本已加载,等待运行指令...")
  340. print("请使用以下命令运行测试:")
  341. print("python tests/test_image_search_api.py")
  342. print("\n或者在Python交互式环境中运行:")
  343. print("from tests.test_image_search_api import ImageSearchAPITest")
  344. print("test = ImageSearchAPITest()")
  345. print("test.run_all_tests()")
  346. print("\n注意:本脚本不会自动运行测试,需要手动调用 run_all_tests() 方法")