import os import dotenv import datetime from pathlib import Path from langchain.agents import create_agent, AgentState from langchain_openai import ChatOpenAI from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from tools.tool_factory import get_all_tools from langchain_core.runnables import RunnableConfig from langchain.agents.middleware import before_model from langgraph.runtime import Runtime from typing import Any, List, Sequence from langchain.messages import RemoveMessage from langgraph.graph.message import REMOVE_ALL_MESSAGES import sqlite3 from config.settings import settings dotenv.load_dotenv() def create_system_prompt( backend_url: str = "", token: str = "", username: str = "default" ) -> str: """ 创建动态的system_prompt,支持参数化配置 Args: backend_url: 后端API地址 token: 访问后端的认证令牌 username: 用户名 """ # 判断token状态 if token: token_status = "已配置有效的认证令牌,可以调用后端API获取用户数据" else: token_status = "未提供认证令牌,后端API调用可能受限" # 判断backend_url状态 if backend_url: backend_status = f"已配置后端地址: {backend_url}" else: backend_status = "未配置后端地址,只能访问知识库" system_prompt = f"""你是属于龙嘉软件公司的AI助手,名字叫小龙。 # 当前会话信息 - 当前用户: {username} - 后端服务状态: {backend_status} - 认证状态: {token_status} 现在时间是{datetime.datetime.now().isoformat()} # 核心能力 你可为客户提供ERP问题的解决方案,也可回答与龙嘉软件相关的问题。软件面向全球客户,你需按用户提问的语言回答。 # 工作流程 1. 分析用户问题的意图,提取关键词 2. 根据意图及关键词调用相应工具 3. 可以访问知识库工具,也可以调用后端API获取数据 # 后端API使用指南 {"- 当用户需要查询个人数据、订单信息、账户状态时,可使用后端API" if backend_url and token else "- 由于缺少认证信息,暂时无法调用后端API"} {"- 后端地址: " + backend_url if backend_url else ""} {"- API调用会自动包含用户的认证令牌: " + token if token else ""} # 知识库搜索规则 - 判断问题所属模块(销售、采购、生产、财务、仓储、权限等)并纳入关键词 - 文章匹配要精准,例如"销售订单新建权限",拆分为:"销售订单"、"新建"、"权限" - 避免使用"的"、"地"、"得"、"了"、"在"等无意义词汇 - 关键词可以多个,要判断问题属于哪个模块并将其纳入关键字 - 如果匹配文章太少(少于3篇),尝试以下方法: a) 变更关键字(同义词、近义词) b) 把关键字拆得更细 c) 扩大搜索范围(减少关键词数量) d) 重新搜索 - 获取到文章列表后,用工具获取文章内容然后回答用户问题 # 回答策略 - 优先使用知识库中的准确信息 - 如果知识库中有相关文章,结合文章内容进行回答 - 如果需要实时数据且认证有效,可调用后端API - 如果找不到对应知识库文章,向客户说明:"我正在学习这个问题的解决方案,很快就能正式为您服务" - 如果用户的问题需要后端数据但认证无效,提示:"查看个人数据需要登录验证,请确保已提供正确的访问令牌" # 注意事项 - 保护用户隐私,不在回复中暴露敏感信息 - 如果API调用失败,提供友好的错误信息 - 保持回答的专业性和准确性 - 对于不确定的问题,可以建议用户联系客服或技术支持 """ # print(system_prompt) return system_prompt def get_day_number(date=None): """获取日期编号 (YYYYMMDD 格式)""" if date is None: date = datetime.datetime.now() return date.strftime("%Y%m%d") # 格式: 20251229 def get_sqlite_checkpointer(): """创建按天分割的SQLite检查点保存器""" try: from langgraph.checkpoint.sqlite import SqliteSaver # 获取当前日期编号 current_day = get_day_number() # 数据库文件存放目录 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) base_dir = os.path.join(project_root, "data", "checkpoints") os.makedirs(base_dir, exist_ok=True) # 数据库文件名格式: checkpoints_20251229.db db_filename = f"checkpoints_{current_day}.db" db_path = os.path.join(base_dir, db_filename) # checkpointer = SqliteSaver.from_conn_string(db_path) conn = sqlite3.connect(db_path, check_same_thread=False) conn.execute("PRAGMA wal_autocheckpoint=500") # 2MB 就提交 conn.execute("PRAGMA journal_size_limit=52428800") # 最大 50MB checkpointer = SqliteSaver(conn) return checkpointer except Exception as e: print(f"❌❌ 创建 SQLite 检查器失败: {e}") import traceback traceback.print_exc() # 回退到内存保存器 from langgraph.checkpoint.memory import InMemorySaver print("⚠️ 使用 InMemorySaver 作为回退") return InMemorySaver() def cleanup_old_checkpoints(max_days=7): """清理超过指定天数的旧检查点文件(可选功能)""" try: project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) base_dir = os.path.join(project_root, "data", "checkpoints") if not os.path.exists(base_dir): return # 获取当前日期 current_date = datetime.datetime.now() # 遍历目录中的所有.db文件 for filename in os.listdir(base_dir): if filename.startswith("checkpoints_") and filename.endswith(".db"): try: print(f"检查旧检查点文件: {filename}") # 提取日期 (checkpoints_day_20251229.db -> 20251229) date_str = filename.replace("checkpoints_day_", "").replace( ".db", "" ) file_date = datetime.datetime.strptime(date_str, "%Y%m%d") # 计算天数差 days_diff = (current_date - file_date).days # 删除超过 max_days 天的旧数据 if days_diff > max_days: file_path = os.path.join(base_dir, filename) os.remove(file_path) print(f"🧹🧹 清理旧检查点文件: {filename} (超过 {max_days} 天)") except (ValueError, IndexError): # 文件名不符合预期,跳过 continue except Exception as e: print(f"⚠️ 清理旧检查点失败: {e}") # 创建agent def create_langchain_agent( backend_url: str = "", token: str = "", username: str = "default", thread_id: str = "default", ): llm = ChatOpenAI( model=settings.LLM_MODEL, temperature=settings.LLM_TEMPERATURE, api_key=settings.DEEPSEEK_API_KEY, base_url=settings.DEEPSEEK_BASE_URL, max_tokens=settings.LLM_MAX_TOKENS, ) tools = get_all_tools() # 添加调试信息 print(f"🔧🔧🔧🔧 Agent 创建调试信息:") print(f" - 用户: {username}") print(f" - Thread ID: {thread_id}") print(f" - 后端地址: {backend_url}") print(f" - Token: {'已提供' if token else '未提供'}") print(f" - 工具数量: {len(tools)}") for i, tool in enumerate(tools): print(f" - 工具 {i+1}: {tool.name}") # 获取动态的system_prompt system_prompt = create_system_prompt(backend_url, token, username) def simple_turn_based_trim( messages: Sequence[BaseMessage], keep_turns: int = 3, system_message: BaseMessage = None, ) -> List[BaseMessage]: """ 修正版:按完整对话轮次修剪消息 每轮对话从Human开始,到下一个Human之前结束 """ if not messages: return [] # 分离系统消息(始终保留) system_messages = [] other_messages = [] for msg in messages: if ( isinstance(msg, SystemMessage) or getattr(msg, "type", None) == "system" or getattr(msg, "role", None) == "system" or msg.__class__.__name__ == "SystemMessage" ): system_messages.append(msg) else: other_messages.append(msg) if len(other_messages) <= 1: return system_messages + other_messages # 找出所有Human消息的位置 human_indices = [] for i, msg in enumerate(other_messages): if ( isinstance(msg, HumanMessage) or getattr(msg, "type", None) == "human" or getattr(msg, "role", None) == "user" ): human_indices.append(i) # 如果Human消息不足keep_turns轮,返回所有 if not human_indices or len(human_indices) <= keep_turns: return system_messages + other_messages # 计算起始索引 start_idx = human_indices[-keep_turns] # 获取要保留的消息 preserved_messages = other_messages[start_idx:] # 4. 返回从该索引开始的所有消息 result = system_messages + preserved_messages # print(f"修剪后消息数: {len(result)}") return result @before_model def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: """Keep only the last few messages to fit context window.""" messages = state["messages"] if len(messages) <= 3: return None # No changes needed # 保留最后4轮对话 trimmed_messages = simple_turn_based_trim(messages, keep_turns=4) return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages} # 使用SQLiteSaver(按天分割) checkpointer = get_sqlite_checkpointer() # print(f"打印检查点保存器: {checkpointer}") # 可选:清理旧检查点(可配置为定期执行) if os.getenv("AUTO_CLEANUP", "false").lower() == "true": cleanup_old_checkpoints(max_days=7) # 保留最近7天数据 agent = create_agent( llm, tools, checkpointer=checkpointer, system_prompt=system_prompt, middleware=[trim_messages], ) return agent