agent.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import os
  2. import dotenv
  3. import datetime
  4. from pathlib import Path
  5. from langchain.agents import create_agent, AgentState
  6. from langchain_openai import ChatOpenAI
  7. from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
  8. from tools.tool_factory import get_all_tools
  9. from langchain_core.runnables import RunnableConfig
  10. from langchain.agents.middleware import before_model
  11. from langgraph.runtime import Runtime
  12. from typing import Any, List, Sequence
  13. from langchain.messages import RemoveMessage
  14. from langgraph.graph.message import REMOVE_ALL_MESSAGES
  15. import sqlite3
  16. from config.settings import settings
  17. dotenv.load_dotenv()
  18. def create_system_prompt(
  19. backend_url: str = "", token: str = "", username: str = "default"
  20. ) -> str:
  21. auth_status = "已认证" if token else "未认证"
  22. backend_available = "API可用" if backend_url and token else "仅数据查询"
  23. knowledge_status = (
  24. "知识库可用" if settings.KNOWLEDGE_BASE_ENABLED else "知识库已禁用"
  25. )
  26. if settings.KNOWLEDGE_BASE_ENABLED:
  27. # 知识库启用时的提示词
  28. system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
  29. 职责:ERP数据查询和问题解答,按用户语言回答。
  30. 工作流:
  31. 1. 分析问题意图,提取模块关键词
  32. 2. 如果是数据查询类问题,直接调用相关工具查询数据
  33. 3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
  34. 4. 关键词要精准,避免无意义词
  35. 回答规则:
  36. - 知识库找不到时提示"正在学习该问题"
  37. - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
  38. - 保护隐私,专业准确,精炼简要
  39. {"- 后端地址: " + backend_url if backend_url else ""}
  40. {"- API用户的认证令牌: " + token if token else ""}
  41. 时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  42. 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  43. | 列名1 | 列名2 | 列名3 |
  44. | :--- | :--- | :--- |
  45. | 数据1 | 数据2 | 数据3 |
  46. | 数据4 | 数据5 | 数据6 |
  47. """
  48. else:
  49. # 知识库禁用时的提示词 - 严格限制回答范围
  50. system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
  51. 职责:仅处理ERP数据查询类问题,按用户语言回答。
  52. 严格限制:
  53. - 知识库功能已禁用,无法回答任何非数据查询类问题
  54. - 禁止回答:疑问解答、操作流程、功能介绍、知识咨询等非数据查询问题
  55. - 禁止使用个人知识或经验进行回答
  56. 工作流:
  57. 1. 分析问题意图,判断是否为数据查询类问题
  58. 2. 如果是数据查询类问题,直接调用相关工具查询数据
  59. 3. 如果是非数据查询类问题(包括疑问、流程、操作等),必须明确回复:"知识库正在完善,无法回答该问题"
  60. 回答规则:
  61. - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:仅处理ERP数据查询类问题
  62. - 非数据查询问题必须回复:"知识库正在完善,无法回答该问题"
  63. - 禁止尝试回答或提供任何建议
  64. - 禁止解释原因或提供替代方案
  65. - 严格按工具提供的数据回答数据查询问题,不能编造答案
  66. - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询支持"}
  67. {"- 后端地址: " + backend_url if backend_url else ""}
  68. {"- API用户的认证令牌: " + token if token else ""}
  69. 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  70. 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  71. | 列名1 | 列名2 | 列名3 |
  72. | :--- | :--- | :--- |
  73. | 数据1 | 数据2 | 数据3 |
  74. | 数据4 | 数据5 | 数据6 |
  75. """
  76. print(system_prompt)
  77. return system_prompt
  78. def get_day_number(date=None):
  79. """获取日期编号 (YYYYMMDD 格式)"""
  80. if date is None:
  81. date = datetime.datetime.now()
  82. return date.strftime("%Y%m%d") # 格式: 20251229
  83. def get_sqlite_checkpointer():
  84. """创建按天分割的SQLite检查点保存器"""
  85. try:
  86. from langgraph.checkpoint.sqlite import SqliteSaver
  87. # 获取当前日期编号
  88. current_day = get_day_number()
  89. # 数据库文件存放目录
  90. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  91. base_dir = os.path.join(project_root, "data", "checkpoints")
  92. os.makedirs(base_dir, exist_ok=True)
  93. # 数据库文件名格式: checkpoints_20251229.db
  94. db_filename = f"checkpoints_{current_day}.db"
  95. db_path = os.path.join(base_dir, db_filename)
  96. # checkpointer = SqliteSaver.from_conn_string(db_path)
  97. conn = sqlite3.connect(db_path, check_same_thread=False)
  98. conn.execute("PRAGMA wal_autocheckpoint=500") # 2MB 就提交
  99. conn.execute("PRAGMA journal_size_limit=52428800") # 最大 50MB
  100. checkpointer = SqliteSaver(conn)
  101. return checkpointer
  102. except Exception as e:
  103. print(f"[ERROR]创建 SQLite 检查器失败: {e}")
  104. import traceback
  105. traceback.print_exc()
  106. # 回退到内存保存器
  107. from langgraph.checkpoint.memory import InMemorySaver
  108. print("[WARN]使用 InMemorySaver 作为回退")
  109. return InMemorySaver()
  110. def cleanup_old_checkpoints(max_days=7):
  111. """清理超过指定天数的旧检查点文件(可选功能)"""
  112. try:
  113. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  114. base_dir = os.path.join(project_root, "data", "checkpoints")
  115. if not os.path.exists(base_dir):
  116. return
  117. # 获取当前日期
  118. current_date = datetime.datetime.now()
  119. # 遍历目录中的所有.db文件
  120. for filename in os.listdir(base_dir):
  121. if filename.startswith("checkpoints_") and filename.endswith(".db"):
  122. try:
  123. print(f"检查旧检查点文件: {filename}")
  124. # 提取日期 (checkpoints_day_20251229.db -> 20251229)
  125. date_str = filename.replace("checkpoints_day_", "").replace(
  126. ".db", ""
  127. )
  128. file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
  129. # 计算天数差
  130. days_diff = (current_date - file_date).days
  131. # 删除超过 max_days 天的旧数据
  132. if days_diff > max_days:
  133. file_path = os.path.join(base_dir, filename)
  134. os.remove(file_path)
  135. print(
  136. f"[CLEAN]清理旧检查点文件: {filename} (超过 {max_days} 天)"
  137. )
  138. except (ValueError, IndexError):
  139. # 文件名不符合预期,跳过
  140. continue
  141. except Exception as e:
  142. print(f"[WARN]清理旧检查点失败: {e}")
  143. # 创建agent
  144. def create_langchain_agent(
  145. backend_url: str = "",
  146. token: str = "",
  147. username: str = "default",
  148. thread_id: str = "default",
  149. ):
  150. llm = ChatOpenAI(
  151. model=settings.LLM_MODEL,
  152. temperature=settings.LLM_TEMPERATURE,
  153. api_key=settings.DEEPSEEK_API_KEY,
  154. base_url=settings.DEEPSEEK_BASE_URL,
  155. max_tokens=settings.LLM_MAX_TOKENS,
  156. )
  157. tools = get_all_tools()
  158. # 添加调试信息
  159. print(f"[DEBUG]Agent 创建调试信息:")
  160. print(f" - 用户: {username}")
  161. print(f" - Thread ID: {thread_id}")
  162. print(f" - 后端地址: {backend_url}")
  163. print(f" - Token: {'已提供' if token else '未提供'}")
  164. print(f" - 工具数量: {len(tools)}")
  165. for i, tool in enumerate(tools):
  166. print(f" - 工具 {i+1}: {tool.name}")
  167. # 获取动态的system_prompt
  168. system_prompt = create_system_prompt(backend_url, token, username)
  169. def simple_turn_based_trim(
  170. messages: Sequence[BaseMessage],
  171. keep_turns: int = 3,
  172. system_message: BaseMessage = None,
  173. ) -> List[BaseMessage]:
  174. """
  175. 修正版:按完整对话轮次修剪消息
  176. 每轮对话从Human开始,到下一个Human之前结束
  177. """
  178. if not messages:
  179. return []
  180. # 分离系统消息(始终保留)
  181. system_messages = []
  182. other_messages = []
  183. for msg in messages:
  184. if (
  185. isinstance(msg, SystemMessage)
  186. or getattr(msg, "type", None) == "system"
  187. or getattr(msg, "role", None) == "system"
  188. or msg.__class__.__name__ == "SystemMessage"
  189. ):
  190. system_messages.append(msg)
  191. else:
  192. other_messages.append(msg)
  193. if len(other_messages) <= 1:
  194. return system_messages + other_messages
  195. # 找出所有Human消息的位置
  196. human_indices = []
  197. for i, msg in enumerate(other_messages):
  198. if (
  199. isinstance(msg, HumanMessage)
  200. or getattr(msg, "type", None) == "human"
  201. or getattr(msg, "role", None) == "user"
  202. ):
  203. human_indices.append(i)
  204. # 如果Human消息不足keep_turns轮,返回所有
  205. if not human_indices or len(human_indices) <= keep_turns:
  206. return system_messages + other_messages
  207. # 计算起始索引
  208. start_idx = human_indices[-keep_turns]
  209. # 获取要保留的消息
  210. preserved_messages = other_messages[start_idx:]
  211. # 4. 返回从该索引开始的所有消息
  212. result = system_messages + preserved_messages
  213. # print(f"修剪后消息数: {len(result)}")
  214. return result
  215. @before_model
  216. def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
  217. """Keep only the last few messages to fit context window."""
  218. messages = state["messages"]
  219. if len(messages) <= 3:
  220. return None # No changes needed
  221. # 保留最后4轮对话
  222. trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
  223. return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
  224. # 使用SQLiteSaver(按天分割)
  225. checkpointer = get_sqlite_checkpointer()
  226. # print(f"打印检查点保存器: {checkpointer}")
  227. # 可选:清理旧检查点(可配置为定期执行)
  228. if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
  229. cleanup_old_checkpoints(max_days=7) # 保留最近7天数据
  230. agent = create_agent(
  231. llm,
  232. tools,
  233. checkpointer=checkpointer,
  234. system_prompt=system_prompt,
  235. middleware=[trim_messages],
  236. )
  237. return agent