| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- import os
- import dotenv
- import datetime
- from pathlib import Path
- from langchain.agents import create_agent, AgentState
- from langchain_openai import ChatOpenAI
- from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
- from tools.tool_factory import get_all_tools
- from langchain_core.runnables import RunnableConfig
- from langchain.agents.middleware import before_model
- from langgraph.runtime import Runtime
- from typing import Any, List, Sequence
- from langchain.messages import RemoveMessage
- from langgraph.graph.message import REMOVE_ALL_MESSAGES
- import sqlite3
- from config.settings import settings
- dotenv.load_dotenv()
def create_system_prompt(
    backend_url: str = "", token: str = "", username: str = "default"
) -> str:
    """Build the assistant's system prompt for one user session.

    The wording adapts to the supplied credentials:

    * With a token, the workflow tells the model to prefer the knowledge base
      but call the API when needed. Without one, it is limited to knowledge-base
      search.
    * The "service" field says "API available" only when both ``backend_url``
      and ``token`` are non-empty.

    The current local time (``MM-DD HH:MM``) is captured when this is called.

    NOTE(review): when provided, ``backend_url`` and the raw ``token`` are
    embedded verbatim in the returned text, so anything that receives this
    prompt sees the token — confirm this is intended.

    Returns:
        The prompt text, one line per entry, terminated by a newline.
    """
    auth_label = "已认证" if token else "未认证"
    service_label = "API可用" if backend_url and token else "仅知识库"
    search_step = "优先知识库搜索,需要时调用API" if token else "仅使用知识库搜索"
    personal_data_rule = (
        "需要个人数据时验证认证状态" if backend_url else "仅提供知识库支持"
    )
    # Optional lines collapse to empty strings (leaving a blank line) when absent.
    backend_line = f"- 后端地址: {backend_url}" if backend_url else ""
    token_line = f"- API用户的认证令牌: {token}" if token else ""
    timestamp = datetime.datetime.now().strftime("%m-%d %H:%M")

    prompt_lines = [
        f"小龙助手(龙嘉软件)- 用户:{username} 认证:{auth_label} 服务:{service_label}",
        "职责:ERP问题解答,按用户语言回答。",
        "工作流:",
        "1. 分析问题意图,提取模块关键词",
        f"2. {search_step}",
        "3. 关键词要精准,避免无意义词",
        "回答规则:",
        '- 知识库优先,找不到时提示"正在学习该问题"',
        f"- {personal_data_rule}",
        "- 保护隐私,专业准确",
        backend_line,
        token_line,
        f"时间:{timestamp}",
        "库存及销量查询结果尽量以 Markdown 表格格式输出,格式如下:",
        "| 列名1 | 列名2 | 列名3 |",
        "| :--- | :--- | :--- |",
        "| 数据1 | 数据2 | 数据3 |",
        "| 数据4 | 数据5 | 数据6 |",
    ]
    return "\n".join(prompt_lines) + "\n"
def get_day_number(date=None):
    """Return *date* formatted as a ``YYYYMMDD`` string.

    When *date* is None, the current local time is used.
    Example: ``datetime.datetime(2025, 12, 29)`` -> ``"20251229"``.
    """
    moment = datetime.datetime.now() if date is None else date
    return moment.strftime("%Y%m%d")
def get_sqlite_checkpointer():
    """Create a SQLite checkpoint saver backed by one database file per day.

    The database is ``<project_root>/data/checkpoints/checkpoints_YYYYMMDD.db``.
    The date is taken once, when this function is called, so a saver created
    before midnight keeps writing to that day's file.

    On any failure (including ``langgraph.checkpoint.sqlite`` not being
    importable), the error and traceback are printed, any connection already
    opened is closed, and an in-memory saver is returned instead. In that case
    checkpoints are not persisted to disk.

    Returns:
        A ``SqliteSaver`` on success, otherwise an ``InMemorySaver``.
    """
    conn = None
    try:
        from langgraph.checkpoint.sqlite import SqliteSaver

        current_day = get_day_number()
        # project_root is the parent of the directory containing this file.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        base_dir = os.path.join(project_root, "data", "checkpoints")
        os.makedirs(base_dir, exist_ok=True)
        # File name format: checkpoints_20251229.db
        db_path = os.path.join(base_dir, f"checkpoints_{current_day}.db")

        # check_same_thread=False lets the connection be used from threads
        # other than the one that created it.
        conn = sqlite3.connect(db_path, check_same_thread=False)
        # NOTE(review): both PRAGMAs below only matter in WAL journal mode,
        # which this function does not enable itself; presumably SqliteSaver
        # switches the connection to WAL — confirm.
        # Auto-checkpoint the WAL every 500 pages (~2 MB at SQLite's default
        # 4 KiB page size).
        conn.execute("PRAGMA wal_autocheckpoint=500")
        # Cap the size of the journal/WAL file left on disk at 50 MiB.
        conn.execute("PRAGMA journal_size_limit=52428800")
        return SqliteSaver(conn)
    except Exception as e:
        print(f"[ERROR]创建 SQLite 检查器失败: {e}")
        import traceback

        traceback.print_exc()
        # Don't leak the connection if setup failed after it was opened.
        if conn is not None:
            conn.close()
        # Fall back to a non-persistent in-memory saver.
        from langgraph.checkpoint.memory import InMemorySaver

        print("[WARN]使用 InMemorySaver 作为回退")
        return InMemorySaver()
def cleanup_old_checkpoints(max_days=7, base_dir=None):
    """Delete daily checkpoint databases older than *max_days* days.

    Recognised file names:

    * ``checkpoints_YYYYMMDD.db`` — the format ``get_sqlite_checkpointer`` writes.
    * ``checkpoints_day_YYYYMMDD.db`` — a legacy format.

    A file is deleted when whole days between its date (midnight) and now
    exceed *max_days*. Its SQLite sidecar files (``-wal``, ``-shm``,
    ``-journal``) are deleted with it. Names that do not parse as a date are
    skipped. A failure to delete one file is reported and does not stop the
    sweep. Any other unexpected error is printed as a warning instead of being
    raised, because this is best-effort housekeeping.

    Args:
        max_days: Age threshold in days.
        base_dir: Directory to scan. Defaults to
            ``<project_root>/data/checkpoints``, the directory used by
            ``get_sqlite_checkpointer``.
    """
    try:
        if base_dir is None:
            project_root = os.path.dirname(
                os.path.dirname(os.path.abspath(__file__))
            )
            base_dir = os.path.join(project_root, "data", "checkpoints")
        if not os.path.exists(base_dir):
            return

        current_date = datetime.datetime.now()
        # os.listdir returns a snapshot list, so deleting inside the loop is safe.
        for filename in os.listdir(base_dir):
            if not (filename.startswith("checkpoints_") and filename.endswith(".db")):
                continue
            print(f"检查旧检查点文件: {filename}")
            # checkpoints_20251229.db / checkpoints_day_20251229.db -> 20251229.
            # The prefix must match the name built in get_sqlite_checkpointer.
            date_str = (
                filename.removeprefix("checkpoints_")
                .removesuffix(".db")
                .removeprefix("day_")
            )
            try:
                file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
            except ValueError:
                # Not a dated checkpoint file; leave it alone.
                continue

            if (current_date - file_date).days <= max_days:
                continue

            file_path = os.path.join(base_dir, filename)
            try:
                os.remove(file_path)
            except OSError as e:
                # e.g. the file is still open elsewhere; skip it, keep sweeping.
                print(f"[WARN]清理旧检查点失败: {filename}: {e}")
                continue
            for suffix in ("-wal", "-shm", "-journal"):
                sidecar = file_path + suffix
                if os.path.exists(sidecar):
                    try:
                        os.remove(sidecar)
                    except OSError as e:
                        print(f"[WARN]清理旧检查点失败: {filename}{suffix}: {e}")
            print(f"[CLEAN]清理旧检查点文件: {filename} (超过 {max_days} 天)")
    except Exception as e:
        print(f"[WARN]清理旧检查点失败: {e}")
def create_langchain_agent(
    backend_url: str = "",
    token: str = "",
    username: str = "default",
    thread_id: str = "default",
):
    """Build the ERP assistant agent for one user session.

    The agent combines:

    * the chat model configured in ``settings``;
    * every tool returned by ``get_all_tools()``;
    * a system prompt from ``create_system_prompt``;
    * the checkpointer from ``get_sqlite_checkpointer()``;
    * a ``before_model`` middleware that trims the stored conversation to
      the last 4 turns.

    When the ``AUTO_CLEANUP`` environment variable is ``"true"``
    (case-insensitive), old checkpoint files are also cleaned up.

    Args:
        backend_url: Backend API base URL, forwarded to the system prompt.
        token: The user's API auth token, forwarded to the system prompt.
        username: User name shown in the system prompt.
        thread_id: Only printed in the debug output below; it is not passed
            to the agent.

    Returns:
        The agent returned by ``create_agent``.
    """
    llm = ChatOpenAI(
        model=settings.LLM_MODEL,
        temperature=settings.LLM_TEMPERATURE,
        api_key=settings.DEEPSEEK_API_KEY,
        base_url=settings.DEEPSEEK_BASE_URL,
        max_tokens=settings.LLM_MAX_TOKENS,
    )
    tools = get_all_tools()

    # Debug output; the token value itself is never printed, only its presence.
    print("[DEBUG]Agent 创建调试信息:")
    print(f" - 用户: {username}")
    print(f" - Thread ID: {thread_id}")
    print(f" - 后端地址: {backend_url}")
    print(f" - Token: {'已提供' if token else '未提供'}")
    print(f" - 工具数量: {len(tools)}")
    for i, tool in enumerate(tools, start=1):
        print(f" - 工具 {i}: {tool.name}")

    system_prompt = create_system_prompt(backend_url, token, username)

    def simple_turn_based_trim(
        messages: Sequence[BaseMessage],
        keep_turns: int = 3,
    ) -> List[BaseMessage]:
        """Keep system messages plus the last *keep_turns* conversation turns.

        A turn starts at a human/user message and runs up to, but not
        including, the next one. System messages are always kept and are
        placed first in the result.

        If there are at most one non-system message, no human messages, or at
        most *keep_turns* human messages, every message is kept.
        """
        if not messages:
            return []

        # Separate system messages (always kept) from the conversation.
        system_messages = []
        other_messages = []
        for msg in messages:
            if (
                isinstance(msg, SystemMessage)
                or getattr(msg, "type", None) == "system"
                or getattr(msg, "role", None) == "system"
                or msg.__class__.__name__ == "SystemMessage"
            ):
                system_messages.append(msg)
            else:
                other_messages.append(msg)

        if len(other_messages) <= 1:
            return system_messages + other_messages

        # Positions where a turn starts (human / user messages).
        human_indices = [
            i
            for i, msg in enumerate(other_messages)
            if (
                isinstance(msg, HumanMessage)
                or getattr(msg, "type", None) == "human"
                or getattr(msg, "role", None) == "user"
            )
        ]
        if not human_indices or len(human_indices) <= keep_turns:
            return system_messages + other_messages

        # Cut on a human-message boundary so no kept turn is split mid-way.
        start_idx = human_indices[-keep_turns]
        return system_messages + other_messages[start_idx:]

    @before_model
    def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Trim the stored history to the last 4 turns before each model call.

        Returns None (no state update) when trimming would not change the
        message list. This avoids rewriting the whole history on every call.
        """
        messages = state["messages"]
        if len(messages) <= 3:
            return None  # No changes needed.
        trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
        if len(trimmed_messages) == len(messages) and all(
            kept is orig for kept, orig in zip(trimmed_messages, messages)
        ):
            # Nothing dropped or reordered: skip the REMOVE_ALL + re-add round trip.
            return None
        # Replace the entire stored history with the trimmed list.
        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}

    # SQLite saver split into one database file per day.
    checkpointer = get_sqlite_checkpointer()

    # Optional cleanup of old checkpoint files, enabled via AUTO_CLEANUP.
    if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
        cleanup_old_checkpoints(max_days=7)  # Keep roughly the last 7 days.

    agent = create_agent(
        llm,
        tools,
        checkpointer=checkpointer,
        system_prompt=system_prompt,
        middleware=[trim_messages],
    )
    return agent
|