agent.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import os
  2. import dotenv
  3. import datetime
  4. from pathlib import Path
  5. from langchain.agents import create_agent, AgentState
  6. from langchain_openai import ChatOpenAI
  7. from langchain_core.messages import (
  8. SystemMessage,
  9. HumanMessage,
  10. BaseMessage,
  11. trim_messages,
  12. )
  13. from tools.tool_factory import get_all_tools
  14. from langchain_core.runnables import RunnableConfig
  15. from langchain.agents.middleware import before_model
  16. from langgraph.runtime import Runtime
  17. from typing import Any, List, Sequence
  18. from langchain.messages import RemoveMessage
  19. from langgraph.graph.message import REMOVE_ALL_MESSAGES
  20. import sqlite3
  21. from config.settings import settings
  22. from langchain_core.messages.utils import count_tokens_approximately
  23. dotenv.load_dotenv()
  24. def create_system_prompt(
  25. backend_url: str = "", token: str = "", username: str = "default"
  26. ) -> str:
  27. auth_status = "已认证" if token else "未认证"
  28. backend_available = "API可用" if backend_url and token else "仅数据查询"
  29. knowledge_status = (
  30. "知识库可用" if settings.KNOWLEDGE_BASE_ENABLED else "知识库已禁用"
  31. )
  32. echart_status = "图表可用" if settings.ECHARTS_ENABLED else "图表已禁用"
  33. if settings.KNOWLEDGE_BASE_ENABLED:
  34. # 知识库启用时的提示词
  35. system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
  36. 职责:ERP数据查询和问题解答,按用户语言回答。
  37. **核心安全指令 (必遵)**:
  38. 1. **当前凭据 (每次工具调用必须使用)**:
  39. - 后端地址: {backend_url if backend_url else '无'}
  40. - API令牌: {token if token else '无'}
  41. 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
  42. 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
  43. 工作流:
  44. 1. 分析问题意图,提取模块关键词
  45. 2. 如果是数据查询类问题,直接调用相关工具查询数据
  46. 3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
  47. 4. 关键词要精准,避免无意义词
  48. 工具调用规格:
  49. - 如果连续3次调用相同工具相同参数,自动停止
  50. - 工具返回相同结果但仍在重复调用时,自动停止
  51. 回答规则:
  52. - 知识库找不到时提示"正在学习该问题"
  53. - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
  54. - 保护隐私,专业准确,精炼简要
  55. 时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  56. 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  57. | 列名1 | 列名2 | 列名3 |
  58. | :--- | :--- | :--- |
  59. | 数据1 | 数据2 | 数据3 |
  60. | 数据4 | 数据5 | 数据6 |
  61. """
  62. else:
  63. # 知识库禁用时的提示词 - 灵活处理工具返回结果
  64. system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
  65. 职责:处理ERP数据查询类问题,按用户语言回答。
  66. **核心安全指令 (必遵)**:
  67. 1. **当前凭据 (每次工具调用必须使用)**:
  68. - 后端地址: {backend_url if backend_url else '无'}
  69. - API令牌: {token if token else '无'}
  70. 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
  71. 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
  72. 工作流:
  73. 1. 分析问题意图,判断是否为数据查询类问题
  74. 2. 如果是数据查询类问题,直接调用相关工具查询数据
  75. 3. 根据工具返回的结果进行回答:
  76. - 如果工具返回了具体数据,按数据内容回答
  77. - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
  78. - 如果工具返回空数据或"未找到数据",如实告知用户
  79. 4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
  80. 工具调用规格:
  81. - 禁止连续调用相同工具相同参数
  82. - 工具返回相同结果但仍在重复调用时,自动停止
  83. 回答规则:
  84. - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
  85. - 工具提示没有权限时,明确回复用户没有权限
  86. - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
  87. - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
  88. - 保持专业、准确、简洁的回答风格
  89. {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
  90. 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  91. 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  92. | 列名1 | 列名2 | 列名3 |
  93. | :--- | :--- | :--- |
  94. | 数据1 | 数据2 | 数据3 |
  95. | 数据4 | 数据5 | 数据6 |
  96. """
  97. if settings.ECHARTS_ENABLED:
  98. system_prompt = (
  99. system_prompt
  100. + """
  101. 并且根据数据的格式,主动选择合适的图表输出,你可以输出柱状图、折线图、饼图。
  102. 饼图格式范例如下:
  103. ```echarts
  104. {{
  105. "title": {{
  106. "text": "浏览器份额", "left": "center" }}
  107. "tooltip": {{
  108. "trigger": "item" }},
  109. "legend": {{
  110. "orient": "vertical", "left": "left" }},
  111. "series": [
  112. {{
  113. "name": "Share",
  114. "type": "pie",
  115. "radius": "55%",
  116. "center": ["50%", "60%"],
  117. "data": [
  118. {{"value": 1048, "name": "Chrome" }}
  119. {{"value": 735, "name": "Firefox" }}
  120. {{"value": 580, "name": "Edge" }}
  121. ]
  122. }}
  123. ]
  124. }}
  125. ```
  126. 柱状图格式范例如下:
  127. ```echarts
  128. {{
  129. "title": {{"text": "每周销量" }}
  130. "tooltip": {{}},
  131. "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}
  132. "yAxis": {{"type": "value" }}
  133. "series": [
  134. {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }}
  135. ]
  136. }}
  137. ```
  138. 折线图,格式范例如下:
  139. ```echarts
  140. {{
  141. "title": {{ "text": "温度趋势" }},
  142. "tooltip": {{ "trigger": "axis" }},
  143. "legend": {{ "data": ["最高", "最低"] }},
  144. "xAxis": {{
  145. "type": "category",
  146. "boundaryGap": false,
  147. "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
  148. }},
  149. "yAxis": {{"type": "value" }},
  150. "series": [
  151. {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }}
  152. {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }}
  153. ]
  154. }}
  155. ```
  156. """
  157. )
  158. return system_prompt
  159. def get_day_number(date=None):
  160. """获取日期编号 (YYYYMMDD 格式)"""
  161. if date is None:
  162. date = datetime.datetime.now()
  163. return date.strftime("%Y%m%d") # 格式: 20251229
  164. def get_sqlite_checkpointer():
  165. """创建按天分割的SQLite检查点保存器"""
  166. try:
  167. from langgraph.checkpoint.sqlite import SqliteSaver
  168. # 获取当前日期编号
  169. current_day = get_day_number()
  170. # 数据库文件存放目录
  171. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  172. base_dir = os.path.join(project_root, "data", "checkpoints")
  173. os.makedirs(base_dir, exist_ok=True)
  174. # 数据库文件名格式: checkpoints_20251229.db
  175. db_filename = f"checkpoints_{current_day}.db"
  176. db_path = os.path.join(base_dir, db_filename)
  177. # checkpointer = SqliteSaver.from_conn_string(db_path)
  178. conn = sqlite3.connect(db_path, check_same_thread=False)
  179. conn.execute("PRAGMA wal_autocheckpoint=500") # 2MB 就提交
  180. conn.execute("PRAGMA journal_size_limit=52428800") # 最大 50MB
  181. checkpointer = SqliteSaver(conn)
  182. return checkpointer
  183. except Exception as e:
  184. print(f"[ERROR]创建 SQLite 检查器失败: {e}")
  185. import traceback
  186. traceback.print_exc()
  187. # 回退到内存保存器
  188. from langgraph.checkpoint.memory import InMemorySaver
  189. print("[WARN]使用 InMemorySaver 作为回退")
  190. return InMemorySaver()
  191. def cleanup_old_checkpoints(max_days=7):
  192. """清理超过指定天数的旧检查点文件(可选功能)"""
  193. try:
  194. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  195. base_dir = os.path.join(project_root, "data", "checkpoints")
  196. if not os.path.exists(base_dir):
  197. return
  198. # 获取当前日期
  199. current_date = datetime.datetime.now()
  200. # 遍历目录中的所有.db文件
  201. for filename in os.listdir(base_dir):
  202. if filename.startswith("checkpoints_") and filename.endswith(".db"):
  203. try:
  204. print(f"检查旧检查点文件: {filename}")
  205. # 提取日期 (checkpoints_day_20251229.db -> 20251229)
  206. date_str = filename.replace("checkpoints_day_", "").replace(
  207. ".db", ""
  208. )
  209. file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
  210. # 计算天数差
  211. days_diff = (current_date - file_date).days
  212. # 删除超过 max_days 天的旧数据
  213. if days_diff > max_days:
  214. file_path = os.path.join(base_dir, filename)
  215. os.remove(file_path)
  216. print(
  217. f"[CLEAN]清理旧检查点文件: {filename} (超过 {max_days} 天)"
  218. )
  219. except (ValueError, IndexError):
  220. # 文件名不符合预期,跳过
  221. continue
  222. except Exception as e:
  223. print(f"[WARN]清理旧检查点失败: {e}")
  224. # 创建agent
  225. def create_langchain_agent(
  226. backend_url: str = "",
  227. token: str = "",
  228. username: str = "default",
  229. thread_id: str = "default",
  230. ):
  231. llm = ChatOpenAI(
  232. model=settings.LLM_MODEL,
  233. temperature=settings.LLM_TEMPERATURE,
  234. api_key=settings.DEEPSEEK_API_KEY,
  235. base_url=settings.DEEPSEEK_BASE_URL,
  236. max_tokens=settings.LLM_MAX_TOKENS,
  237. )
  238. tools = get_all_tools()
  239. # 添加调试信息
  240. print(f"[DEBUG]Agent 创建调试信息:")
  241. print(f" - 用户: {username}")
  242. print(f" - Thread ID: {thread_id}")
  243. print(f" - 后端地址: {backend_url}")
  244. print(f" - Token: {'已提供' if token else '未提供'}")
  245. print(f" - 工具数量: {len(tools)}")
  246. for i, tool in enumerate(tools):
  247. print(f" - 工具 {i+1}: {tool.name}")
  248. # 获取动态的system_prompt
  249. system_prompt = create_system_prompt(backend_url, token, username)
  250. print(system_prompt)
  251. # def simple_turn_based_trim(
  252. # messages: Sequence[BaseMessage],
  253. # keep_turns: int = 3,
  254. # system_message: BaseMessage = None,
  255. # ) -> List[BaseMessage]:
  256. # """
  257. # 修正版:按完整对话轮次修剪消息
  258. # 每轮对话从Human开始,到下一个Human之前结束
  259. # """
  260. # if not messages:
  261. # return []
  262. # # 分离系统消息(始终保留)
  263. # system_messages = []
  264. # other_messages = []
  265. # for msg in messages:
  266. # if (
  267. # isinstance(msg, SystemMessage)
  268. # or getattr(msg, "type", None) == "system"
  269. # or getattr(msg, "role", None) == "system"
  270. # or msg.__class__.__name__ == "SystemMessage"
  271. # ):
  272. # system_messages.append(msg)
  273. # else:
  274. # other_messages.append(msg)
  275. # if len(other_messages) <= 1:
  276. # return system_messages + other_messages
  277. # # 找出所有Human消息的位置
  278. # human_indices = []
  279. # for i, msg in enumerate(other_messages):
  280. # if (
  281. # isinstance(msg, HumanMessage)
  282. # or getattr(msg, "type", None) == "human"
  283. # or getattr(msg, "role", None) == "user"
  284. # ):
  285. # human_indices.append(i)
  286. # # 如果Human消息不足keep_turns轮,返回所有
  287. # if not human_indices or len(human_indices) <= keep_turns:
  288. # return system_messages + other_messages
  289. # # 计算起始索引
  290. # start_idx = human_indices[-keep_turns]
  291. # # 获取要保留的消息
  292. # preserved_messages = other_messages[start_idx:]
  293. # # 4. 返回从该索引开始的所有消息
  294. # result = system_messages + preserved_messages
  295. # # print(f"修剪后消息数: {len(result)}")
  296. # return result
  297. # @before_model
  298. # def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
  299. # """Keep only the last few messages to fit context window."""
  300. # messages = state["messages"]
  301. # if len(messages) <= 3:
  302. # return None # No changes needed
  303. # # 保留最后4轮对话
  304. # trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
  305. # return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
  306. @before_model
  307. def trim_messages_middleware(
  308. state: AgentState, runtime: Runtime
  309. ) -> dict[str, Any] | None:
  310. """使用官方trim_messages函数修剪消息"""
  311. messages = state.get("messages", [])
  312. print(f"trim_messages_middleware[DEBUG]原始消息数: {len(messages)}")
  313. # if len(messages) <= 3:
  314. # return None # 不需要修剪
  315. trimmed_messages = trim_messages(
  316. messages,
  317. max_tokens=1000,
  318. strategy="last", # 保留最近的对话
  319. token_counter=count_tokens_approximately, # token计数器
  320. start_on="human", # 从human消息开始计算轮次
  321. include_system=True, # 包含系统消息
  322. )
  323. # 添加调试信息
  324. original_count = len(messages)
  325. trimmed_count = len(trimmed_messages)
  326. print(f"trim_messages_middleware[DEBUG]修剪后消息数: {trimmed_count}")
  327. if trimmed_count < original_count:
  328. print(f"[INFO]消息修剪: {original_count} -> {trimmed_count} 条消息")
  329. return {"messages": trimmed_messages}
  330. # 使用SQLiteSaver(按天分割)
  331. checkpointer = get_sqlite_checkpointer()
  332. # print(f"打印检查点保存器: {checkpointer}")
  333. # 可选:清理旧检查点(可配置为定期执行)
  334. if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
  335. cleanup_old_checkpoints(max_days=7) # 保留最近7天数据
  336. agent = create_agent(
  337. llm,
  338. tools,
  339. checkpointer=checkpointer,
  340. system_prompt=system_prompt,
  341. middleware=[trim_messages_middleware],
  342. )
  343. return agent