agent.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import os
  2. import dotenv
  3. import datetime
  4. from pathlib import Path
  5. from langchain.agents import create_agent, AgentState
  6. from langchain_openai import ChatOpenAI
  7. from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
  8. from tools.tool_factory import get_all_tools
  9. from langchain_core.runnables import RunnableConfig
  10. from langchain.agents.middleware import before_model
  11. from langgraph.runtime import Runtime
  12. from typing import Any, List, Sequence
  13. from langchain.messages import RemoveMessage
  14. from langgraph.graph.message import REMOVE_ALL_MESSAGES
  15. import sqlite3
  16. from config.settings import settings
  17. dotenv.load_dotenv()
  18. def create_system_prompt(
  19. backend_url: str = "", token: str = "", username: str = "default"
  20. ) -> str:
  21. """
  22. 创建动态的system_prompt,支持参数化配置
  23. Args:
  24. backend_url: 后端API地址
  25. token: 访问后端的认证令牌
  26. username: 用户名
  27. """
  28. # 判断token状态
  29. if token:
  30. token_status = "已配置有效的认证令牌,可以调用后端API获取用户数据"
  31. else:
  32. token_status = "未提供认证令牌,后端API调用可能受限"
  33. # 判断backend_url状态
  34. if backend_url:
  35. backend_status = f"已配置后端地址: {backend_url}"
  36. else:
  37. backend_status = "未配置后端地址,只能访问知识库"
  38. system_prompt = f"""你是属于龙嘉软件公司的AI助手,名字叫小龙。
  39. # 当前会话信息
  40. - 当前用户: {username}
  41. - 后端服务状态: {backend_status}
  42. - 认证状态: {token_status}
  43. 现在时间是{datetime.datetime.now().isoformat()}
  44. # 核心能力
  45. 你可为客户提供ERP问题的解决方案,也可回答与龙嘉软件相关的问题。软件面向全球客户,你需按用户提问的语言回答。
  46. # 工作流程
  47. 1. 分析用户问题的意图,提取关键词
  48. 2. 根据意图及关键词调用相应工具
  49. 3. 可以访问知识库工具,也可以调用后端API获取数据
  50. # 后端API使用指南
  51. {"- 当用户需要查询个人数据、订单信息、账户状态时,可使用后端API" if backend_url and token else "- 由于缺少认证信息,暂时无法调用后端API"}
  52. {"- 后端地址: " + backend_url if backend_url else ""}
  53. {"- API调用会自动包含用户的认证令牌: " + token if token else ""}
  54. # 知识库搜索规则
  55. - 判断问题所属模块(销售、采购、生产、财务、仓储、权限等)并纳入关键词
  56. - 文章匹配要精准,例如"销售订单新建权限",拆分为:"销售订单"、"新建"、"权限"
  57. - 避免使用"的"、"地"、"得"、"了"、"在"等无意义词汇
  58. - 关键词可以多个,要判断问题属于哪个模块并将其纳入关键字
  59. - 如果匹配文章太少(少于3篇),尝试以下方法:
  60. a) 变更关键字(同义词、近义词)
  61. b) 把关键字拆得更细
  62. c) 扩大搜索范围(减少关键词数量)
  63. d) 重新搜索
  64. - 获取到文章列表后,用工具获取文章内容然后回答用户问题
  65. # 回答策略
  66. - 优先使用知识库中的准确信息
  67. - 如果知识库中有相关文章,结合文章内容进行回答
  68. - 如果需要实时数据且认证有效,可调用后端API
  69. - 如果找不到对应知识库文章,向客户说明:"我正在学习这个问题的解决方案,很快就能正式为您服务"
  70. - 如果用户的问题需要后端数据但认证无效,提示:"查看个人数据需要登录验证,请确保已提供正确的访问令牌"
  71. # 注意事项
  72. - 保护用户隐私,不在回复中暴露敏感信息
  73. - 如果API调用失败,提供友好的错误信息
  74. - 保持回答的专业性和准确性
  75. - 对于不确定的问题,可以建议用户联系客服或技术支持
  76. """
  77. # print(system_prompt)
  78. return system_prompt
  79. def get_day_number(date=None):
  80. """获取日期编号 (YYYYMMDD 格式)"""
  81. if date is None:
  82. date = datetime.datetime.now()
  83. return date.strftime("%Y%m%d") # 格式: 20251229
  84. def get_sqlite_checkpointer():
  85. """创建按天分割的SQLite检查点保存器"""
  86. try:
  87. from langgraph.checkpoint.sqlite import SqliteSaver
  88. # 获取当前日期编号
  89. current_day = get_day_number()
  90. # 数据库文件存放目录
  91. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  92. base_dir = os.path.join(project_root, "data", "checkpoints")
  93. os.makedirs(base_dir, exist_ok=True)
  94. # 数据库文件名格式: checkpoints_20251229.db
  95. db_filename = f"checkpoints_{current_day}.db"
  96. db_path = os.path.join(base_dir, db_filename)
  97. # checkpointer = SqliteSaver.from_conn_string(db_path)
  98. conn = sqlite3.connect(db_path, check_same_thread=False)
  99. conn.execute("PRAGMA wal_autocheckpoint=500") # 2MB 就提交
  100. conn.execute("PRAGMA journal_size_limit=52428800") # 最大 50MB
  101. checkpointer = SqliteSaver(conn)
  102. return checkpointer
  103. except Exception as e:
  104. print(f"❌❌ 创建 SQLite 检查器失败: {e}")
  105. import traceback
  106. traceback.print_exc()
  107. # 回退到内存保存器
  108. from langgraph.checkpoint.memory import InMemorySaver
  109. print("⚠️ 使用 InMemorySaver 作为回退")
  110. return InMemorySaver()
  111. def cleanup_old_checkpoints(max_days=7):
  112. """清理超过指定天数的旧检查点文件(可选功能)"""
  113. try:
  114. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  115. base_dir = os.path.join(project_root, "data", "checkpoints")
  116. if not os.path.exists(base_dir):
  117. return
  118. # 获取当前日期
  119. current_date = datetime.datetime.now()
  120. # 遍历目录中的所有.db文件
  121. for filename in os.listdir(base_dir):
  122. if filename.startswith("checkpoints_") and filename.endswith(".db"):
  123. try:
  124. print(f"检查旧检查点文件: {filename}")
  125. # 提取日期 (checkpoints_day_20251229.db -> 20251229)
  126. date_str = filename.replace("checkpoints_day_", "").replace(
  127. ".db", ""
  128. )
  129. file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
  130. # 计算天数差
  131. days_diff = (current_date - file_date).days
  132. # 删除超过 max_days 天的旧数据
  133. if days_diff > max_days:
  134. file_path = os.path.join(base_dir, filename)
  135. os.remove(file_path)
  136. print(f"🧹🧹 清理旧检查点文件: {filename} (超过 {max_days} 天)")
  137. except (ValueError, IndexError):
  138. # 文件名不符合预期,跳过
  139. continue
  140. except Exception as e:
  141. print(f"⚠️ 清理旧检查点失败: {e}")
  142. # 创建agent
  143. def create_langchain_agent(
  144. backend_url: str = "",
  145. token: str = "",
  146. username: str = "default",
  147. thread_id: str = "default",
  148. ):
  149. llm = ChatOpenAI(
  150. model=settings.LLM_MODEL,
  151. temperature=settings.LLM_TEMPERATURE,
  152. api_key=settings.DEEPSEEK_API_KEY,
  153. base_url=settings.DEEPSEEK_BASE_URL,
  154. max_tokens=settings.LLM_MAX_TOKENS,
  155. )
  156. tools = get_all_tools()
  157. # 添加调试信息
  158. print(f"🔧🔧🔧🔧 Agent 创建调试信息:")
  159. print(f" - 用户: {username}")
  160. print(f" - Thread ID: {thread_id}")
  161. print(f" - 后端地址: {backend_url}")
  162. print(f" - Token: {'已提供' if token else '未提供'}")
  163. print(f" - 工具数量: {len(tools)}")
  164. for i, tool in enumerate(tools):
  165. print(f" - 工具 {i+1}: {tool.name}")
  166. # 获取动态的system_prompt
  167. system_prompt = create_system_prompt(backend_url, token, username)
  168. def simple_turn_based_trim(
  169. messages: Sequence[BaseMessage],
  170. keep_turns: int = 3,
  171. system_message: BaseMessage = None,
  172. ) -> List[BaseMessage]:
  173. """
  174. 修正版:按完整对话轮次修剪消息
  175. 每轮对话从Human开始,到下一个Human之前结束
  176. """
  177. if not messages:
  178. return []
  179. # 分离系统消息(始终保留)
  180. system_messages = []
  181. other_messages = []
  182. for msg in messages:
  183. if (
  184. isinstance(msg, SystemMessage)
  185. or getattr(msg, "type", None) == "system"
  186. or getattr(msg, "role", None) == "system"
  187. or msg.__class__.__name__ == "SystemMessage"
  188. ):
  189. system_messages.append(msg)
  190. else:
  191. other_messages.append(msg)
  192. if len(other_messages) <= 1:
  193. return system_messages + other_messages
  194. # 找出所有Human消息的位置
  195. human_indices = []
  196. for i, msg in enumerate(other_messages):
  197. if (
  198. isinstance(msg, HumanMessage)
  199. or getattr(msg, "type", None) == "human"
  200. or getattr(msg, "role", None) == "user"
  201. ):
  202. human_indices.append(i)
  203. # 如果Human消息不足keep_turns轮,返回所有
  204. if not human_indices or len(human_indices) <= keep_turns:
  205. return system_messages + other_messages
  206. # 计算起始索引
  207. start_idx = human_indices[-keep_turns]
  208. # 获取要保留的消息
  209. preserved_messages = other_messages[start_idx:]
  210. # 4. 返回从该索引开始的所有消息
  211. result = system_messages + preserved_messages
  212. # print(f"修剪后消息数: {len(result)}")
  213. return result
  214. @before_model
  215. def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
  216. """Keep only the last few messages to fit context window."""
  217. messages = state["messages"]
  218. if len(messages) <= 3:
  219. return None # No changes needed
  220. # 保留最后4轮对话
  221. trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
  222. return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
  223. # 使用SQLiteSaver(按天分割)
  224. checkpointer = get_sqlite_checkpointer()
  225. # print(f"打印检查点保存器: {checkpointer}")
  226. # 可选:清理旧检查点(可配置为定期执行)
  227. if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
  228. cleanup_old_checkpoints(max_days=7) # 保留最近7天数据
  229. agent = create_agent(
  230. llm,
  231. tools,
  232. checkpointer=checkpointer,
  233. system_prompt=system_prompt,
  234. middleware=[trim_messages],
  235. )
  236. return agent