ocr_service.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import base64
  2. import requests
  3. from typing import Dict, Any, Optional
  4. from utils.logger import chat_logger
  5. import aiohttp
  6. import asyncio
  7. class PaddleOCRService:
  8. """PaddleOCR API服务封装"""
  9. def __init__(self, api_url: str, token: str):
  10. self.api_url = api_url
  11. self.token = token
  12. self.headers = {
  13. "Authorization": f"token {token}",
  14. "Content-Type": "application/json",
  15. }
  16. async def recognize_image_async(
  17. self,
  18. image_bytes: bytes,
  19. file_type: int = 1, # 1: 图片, 0: PDF
  20. use_doc_orientation_classify: bool = False,
  21. use_doc_unwarping: bool = False,
  22. use_textline_orientation: bool = False,
  23. ) -> Dict[str, Any]:
  24. """异步调用OCR API"""
  25. # 同步调用(如果不需要异步,可以用requests)
  26. return self.recognize_image_sync(
  27. image_bytes=image_bytes,
  28. file_type=file_type,
  29. use_doc_orientation_classify=use_doc_orientation_classify,
  30. use_doc_unwarping=use_doc_unwarping,
  31. use_textline_orientation=use_textline_orientation,
  32. )
  33. def recognize_image_sync(
  34. self,
  35. image_bytes: bytes,
  36. file_type: int = 1,
  37. use_doc_orientation_classify: bool = False,
  38. use_doc_unwarping: bool = False,
  39. use_textline_orientation: bool = False,
  40. ) -> Dict[str, Any]:
  41. """同步调用OCR API"""
  42. try:
  43. # 转换为base64
  44. file_data = base64.b64encode(image_bytes).decode("ascii")
  45. payload = {
  46. "file": file_data,
  47. "fileType": file_type,
  48. "useDocOrientationClassify": use_doc_orientation_classify,
  49. "useDocUnwarping": use_doc_unwarping,
  50. "useTextlineOrientation": use_textline_orientation,
  51. }
  52. response = requests.post(
  53. self.api_url, json=payload, headers=self.headers, timeout=30
  54. )
  55. if response.status_code != 200:
  56. chat_logger.error(
  57. f"OCR API调用失败: {response.status_code}, {response.text}"
  58. )
  59. raise Exception(f"OCR识别失败: {response.status_code}")
  60. result = response.json()
  61. if "result" not in result:
  62. chat_logger.error(f"OCR返回格式错误: {result}")
  63. raise Exception("OCR返回格式错误")
  64. return result["result"]
  65. except requests.exceptions.Timeout:
  66. chat_logger.error("OCR API调用超时")
  67. raise Exception("OCR识别超时")
  68. except Exception as e:
  69. chat_logger.error(f"OCR识别异常: {str(e)}")
  70. raise
  71. def extract_text_from_result(self, json_data: Dict) -> str:
  72. """从OCR结果中提取文本(格式化)"""
  73. try:
  74. ocr_result = json_data["ocrResults"][0]["prunedResult"]
  75. texts = ocr_result["rec_texts"]
  76. scores = ocr_result["rec_scores"]
  77. boxes = ocr_result["rec_boxes"]
  78. ocr_text = "识别文本 | 识别框坐标"
  79. ocr_text += "\r\n" + ("-" * 50)
  80. for text, score, box in zip(texts, scores, boxes):
  81. if score > 0.5:
  82. ocr_text += f"\r\n{text} | {box}"
  83. # # 按置信度过滤并排序
  84. # filtered_results = []
  85. # for i, (text, score) in enumerate(zip(texts, scores)):
  86. # if score > 0.5: # 置信度阈值
  87. # filtered_results.append(
  88. # {
  89. # "index": i,
  90. # "text": text,
  91. # "score": float(score),
  92. # "box": (
  93. # ocr_data["rec_boxes"][i]
  94. # if i < len(ocr_data["rec_boxes"])
  95. # else None
  96. # ),
  97. # }
  98. # )
  99. # # 按位置排序(从上到下,从左到右)
  100. # filtered_results.sort(
  101. # key=lambda x: (
  102. # x["box"][0][1] if x["box"] else 0, # y坐标
  103. # x["box"][0][0] if x["box"] else 0, # x坐标
  104. # )
  105. # )
  106. # # 拼接为文本
  107. # ocr_text = "\n".join([item["text"] for item in filtered_results])
  108. # chat_logger.info(f"OCR识别成功,识别到{len(filtered_results)}个文本块")
  109. return ocr_text
  110. except Exception as e:
  111. chat_logger.error(f"提取OCR文本失败: {str(e)}")
  112. raise