base_template.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, Any
  3. from pydantic import BaseModel
  4. import json
  5. from datetime import datetime
  6. from config.template_config_manager import TemplateConfigManager # 新增导入
  7. class DocumentTemplate(ABC):
  8. """单据模板基类(支持配置扩展)"""
  9. def __init__(self, config_manager: TemplateConfigManager = None):
  10. # 新增:配置管理器
  11. self.config_manager = config_manager or TemplateConfigManager()
  12. self._template_config = self.config_manager.get_template_config(
  13. self.template_name
  14. )
  15. @property
  16. @abstractmethod
  17. def template_name(self) -> str:
  18. """模板名称标识"""
  19. pass
  20. @property
  21. @abstractmethod
  22. def description(self) -> str:
  23. """模板描述"""
  24. pass
  25. def get_hardcoded_guidance(self) -> Dict[str, Any]:
  26. """硬编码的字段指导信息(子类可重写)"""
  27. return {"field_guidance": {}, "additional_rules": ""} # 项目初期为空
  28. @property
  29. def system_prompt(self) -> str:
  30. """系统提示词(支持配置扩展)"""
  31. # 获取硬编码指导
  32. hardcoded_guidance = self.get_hardcoded_guidance()
  33. hardcoded_field_guidance = hardcoded_guidance.get("field_guidance", {})
  34. hardcoded_additional_rules = hardcoded_guidance.get("additional_rules", "")
  35. # 获取配置指导(已自动过滤空白值)
  36. configured_field_guidance = self._template_config.get("field_guidance", {})
  37. configured_additional_rules = self._template_config.get("additional_rules", "")
  38. # 合并字段指导信息
  39. merged_field_guidance = self._merge_field_guidance(
  40. hardcoded_field_guidance, configured_field_guidance
  41. )
  42. base_prompt = f"""你是一个专业的单据信息提取助手。现在时间是{datetime.now().isoformat()}
  43. 请从OCR识别结果或用户的输入中提取信息,并严格按照以下JSON格式返回:
  44. {json.dumps(self.output_schema(), indent=2, ensure_ascii=False)}
  45. 提取规则:
  46. {self.extraction_rules()}"""
  47. # 添加合并后的指导信息(仅在存在内容时添加)
  48. extended_prompt = self._build_extended_guidance(
  49. merged_field_guidance,
  50. hardcoded_additional_rules,
  51. configured_additional_rules,
  52. )
  53. if extended_prompt:
  54. base_prompt += f"\n\n额外指导信息:\n{extended_prompt}"
  55. base_prompt += """
  56. 请确保:
  57. 1. 只提取明确存在的字段
  58. 2. 日期格式统一为YYYY-MM-DD
  59. 3. 数字类型保持原样"""
  60. return base_prompt
  61. def _merge_field_guidance(
  62. self, hardcoded: Dict[str, list], configured: Dict[str, list]
  63. ) -> Dict[str, list]:
  64. """合并硬编码和配置的字段指导信息(自动去重)"""
  65. merged = {}
  66. all_fields = set(hardcoded.keys()) | set(configured.keys())
  67. for field in all_fields:
  68. hardcoded_hints = hardcoded.get(field, [])
  69. configured_hints = configured.get(field, [])
  70. # 合并并去重,保持顺序
  71. combined = []
  72. seen = set()
  73. for hint in hardcoded_hints + configured_hints:
  74. if hint and hint not in seen: # 过滤空值并去重
  75. combined.append(hint)
  76. seen.add(hint)
  77. if combined: # 只添加有内容的字段
  78. merged[field] = combined
  79. return merged
  80. def _build_extended_guidance(
  81. self,
  82. field_guidance: Dict[str, list],
  83. hardcoded_rules: str,
  84. configured_rules: str,
  85. ) -> str:
  86. """构建扩展指导信息"""
  87. guidance_parts = []
  88. # 字段指导信息(仅在存在内容时添加)
  89. if field_guidance:
  90. guidance_parts.append("字段识别指导:")
  91. for field_name, hints in field_guidance.items():
  92. if hints:
  93. combined_hints = "; ".join(hints)
  94. guidance_parts.append(f"- {field_name}: {combined_hints}")
  95. # 合并额外规则(过滤空值)
  96. additional_rules = []
  97. if hardcoded_rules and hardcoded_rules.strip():
  98. additional_rules.append(hardcoded_rules)
  99. if configured_rules and configured_rules.strip():
  100. additional_rules.append(configured_rules)
  101. if additional_rules:
  102. guidance_parts.append(f"特殊规则: {'; '.join(additional_rules)}")
  103. return "\n".join(guidance_parts) if guidance_parts else ""
  104. @abstractmethod
  105. def output_schema(self) -> Dict[str, Any]:
  106. """返回JSON输出结构定义"""
  107. pass
  108. @abstractmethod
  109. def extraction_rules(self) -> str:
  110. """字段提取规则说明"""
  111. pass
  112. def validate_result(self, result: Dict) -> bool:
  113. """验证解析结果"""
  114. return True
  115. def post_process(self, result: Dict) -> Dict:
  116. """后处理钩子"""
  117. result["_metadata"] = {
  118. "template": self.template_name,
  119. "processed_at": datetime.now().isoformat(),
  120. "template_description": self.description,
  121. }
  122. return result