Parcourir la source

改用官方的修剪历史消息函数

longjoedyy il y a 1 jour
Parent
commit
6dd2fd2893
6 fichiers modifiés avec 221 ajouts et 133 suppressions
  1. 1 1
      config/tool_config.json
  2. 161 112
      core/agent.py
  3. 1 1
      core/agent_manager.py
  4. 1 1
      core/chat_service.py
  5. 1 0
      core/document_processor/document_service.py
  6. 56 18
      utils/logger.py

+ 1 - 1
config/tool_config.json

@@ -72,7 +72,7 @@
         "输出格式要求": [
             "以表格输出,如果custname和viewdate数据全部空白,不显示客户名称和收款日期列"
         ],
-        "使用示例": "用户输入:'查询客户A,2025年的收款情况' -> 系统调用此工具获取客户A在2025年1月至12月的收款数据"
+        "使用示例": "用户输入:'查询客户A,2025年的收款情况' -> 系统调用此工具获取客户A在2025年1月至12月的收款数据;'今天收款多少?'->获取所有客户今天的收款数据"
     },
     "get_bmsttake_with_mx": {
         "基础描述": "获取指定客户应收帐数据",

+ 161 - 112
core/agent.py

@@ -5,7 +5,12 @@ from pathlib import Path
 
 from langchain.agents import create_agent, AgentState
 from langchain_openai import ChatOpenAI
-from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
+from langchain_core.messages import (
+    SystemMessage,
+    HumanMessage,
+    BaseMessage,
+    trim_messages,
+)
 from tools.tool_factory import get_all_tools
 from langchain_core.runnables import RunnableConfig
 from langchain.agents.middleware import before_model
@@ -15,6 +20,7 @@ from langchain.messages import RemoveMessage
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
 import sqlite3
 from config.settings import settings
+from langchain_core.messages.utils import count_tokens_approximately
 
 dotenv.load_dotenv()
 
@@ -30,54 +36,68 @@ def create_system_prompt(
     if settings.KNOWLEDGE_BASE_ENABLED:
         # 知识库启用时的提示词
         system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
-职责:ERP数据查询和问题解答,按用户语言回答。
-工作流:
-1. 分析问题意图,提取模块关键词
-2. 如果是数据查询类问题,直接调用相关工具查询数据
-3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
-4. 关键词要精准,避免无意义词
-回答规则:
-- 知识库找不到时提示"正在学习该问题"
-- {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
-- 保护隐私,专业准确,精炼简要
-{"- 后端地址: " + backend_url if backend_url else ""}
-{"- API用户的认证令牌: " + token if token else ""}
-时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-数据查询结果尽量以 Markdown 表格格式输出,格式如下:
-| 列名1 | 列名2 | 列名3 |
-| :--- | :--- | :--- |
-| 数据1 | 数据2 | 数据3 |
-| 数据4 | 数据5 | 数据6 |
-"""
+    职责:ERP数据查询和问题解答,按用户语言回答。
+
+    **核心安全指令 (必遵)**:
+    1.  **当前凭据 (每次工具调用必须使用)**:
+        - 后端地址: {backend_url if backend_url else '无'}
+        - API令牌: {token if token else '无'}
+    2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
+    3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
+    工作流:
+    1. 分析问题意图,提取模块关键词
+    2. 如果是数据查询类问题,直接调用相关工具查询数据
+    3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
+    4. 关键词要精准,避免无意义词
+    工具调用规格:
+    - 如果连续3次调用相同工具相同参数,自动停止
+    - 工具返回相同结果但仍在重复调用时,自动停止
+    回答规则:
+    - 知识库找不到时提示"正在学习该问题"
+    - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
+    - 保护隐私,专业准确,精炼简要
+    时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+    数据查询结果尽量以 Markdown 表格格式输出,格式如下:
+    | 列名1 | 列名2 | 列名3 |
+    | :--- | :--- | :--- |
+    | 数据1 | 数据2 | 数据3 |
+    | 数据4 | 数据5 | 数据6 |
+    """
     else:
-        # 知识库禁用时的提示词 - 严格限制回答范围
+        # 知识库禁用时的提示词 - 灵活处理工具返回结果
         system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
-职责:仅处理ERP数据查询类问题,按用户语言回答。
-严格限制:
-- 知识库功能已禁用,无法回答任何非数据查询类问题
-- 禁止回答:疑问解答、操作流程、功能介绍、知识咨询等非数据查询问题
-- 禁止使用个人知识或经验进行回答
-工作流:
-1. 分析问题意图,判断是否为数据查询类问题
-2. 如果是数据查询类问题,直接调用相关工具查询数据
-3. 如果是非数据查询类问题(包括疑问、流程、操作等),必须明确回复:"知识库正在完善,无法回答该问题"
-回答规则:
-- 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:仅处理ERP数据查询类问题
-- 非数据查询问题必须回复:"知识库正在完善,无法回答该问题"
-- 禁止尝试回答或提供任何建议
-- 禁止解释原因或提供替代方案
-- 严格按工具提供的数据回答数据查询问题,不能编造答案
-- {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询支持"}
-{"- 后端地址: " + backend_url if backend_url else ""}
-{"- API用户的认证令牌: " + token if token else ""}
-当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-数据查询结果尽量以 Markdown 表格格式输出,格式如下:
-| 列名1 | 列名2 | 列名3 |
-| :--- | :--- | :--- |
-| 数据1 | 数据2 | 数据3 |
-| 数据4 | 数据5 | 数据6 |
-"""
-    print(system_prompt)
+    职责:处理ERP数据查询类问题,按用户语言回答。
+    **核心安全指令 (必遵)**:
+    1.  **当前凭据 (每次工具调用必须使用)**:
+        - 后端地址: {backend_url if backend_url else '无'}
+        - API令牌: {token if token else '无'}
+    2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
+    3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
+    工作流:
+    1. 分析问题意图,判断是否为数据查询类问题
+    2. 如果是数据查询类问题,直接调用相关工具查询数据
+    3. 根据工具返回的结果进行回答:
+    - 如果工具返回了具体数据,按数据内容回答
+    - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
+    - 如果工具返回空数据或"未找到数据",如实告知用户
+    4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
+    工具调用规格:
+    - 禁止连续调用相同工具相同参数
+    - 工具返回相同结果但仍在重复调用时,自动停止
+    回答规则:
+    - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
+    - 工具提示没有权限时,明确回复用户没有权限
+    - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
+    - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
+    - 保持专业、准确、简洁的回答风格
+    {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
+    当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+    数据查询结果尽量以 Markdown 表格格式输出,格式如下:
+    | 列名1 | 列名2 | 列名3 |
+    | :--- | :--- | :--- |
+    | 数据1 | 数据2 | 数据3 |
+    | 数据4 | 数据5 | 数据6 |
+    """
     return system_prompt
 
 
@@ -197,72 +217,101 @@ def create_langchain_agent(
     # 获取动态的system_prompt
     system_prompt = create_system_prompt(backend_url, token, username)
 
-    def simple_turn_based_trim(
-        messages: Sequence[BaseMessage],
-        keep_turns: int = 3,
-        system_message: BaseMessage = None,
-    ) -> List[BaseMessage]:
-        """
-        修正版:按完整对话轮次修剪消息
-        每轮对话从Human开始,到下一个Human之前结束
-        """
-        if not messages:
-            return []
-        # 分离系统消息(始终保留)
-        system_messages = []
-        other_messages = []
-        for msg in messages:
-            if (
-                isinstance(msg, SystemMessage)
-                or getattr(msg, "type", None) == "system"
-                or getattr(msg, "role", None) == "system"
-                or msg.__class__.__name__ == "SystemMessage"
-            ):
-                system_messages.append(msg)
-            else:
-                other_messages.append(msg)
-
-        if len(other_messages) <= 1:
-            return system_messages + other_messages
-
-        # 找出所有Human消息的位置
-        human_indices = []
-        for i, msg in enumerate(other_messages):
-            if (
-                isinstance(msg, HumanMessage)
-                or getattr(msg, "type", None) == "human"
-                or getattr(msg, "role", None) == "user"
-            ):
-                human_indices.append(i)
-
-        # 如果Human消息不足keep_turns轮,返回所有
-        if not human_indices or len(human_indices) <= keep_turns:
-            return system_messages + other_messages
-
-        # 计算起始索引
-        start_idx = human_indices[-keep_turns]
-
-        # 获取要保留的消息
-        preserved_messages = other_messages[start_idx:]
-
-        # 4. 返回从该索引开始的所有消息
-        result = system_messages + preserved_messages
-        # print(f"修剪后消息数: {len(result)}")
-
-        return result
+    # def simple_turn_based_trim(
+    #     messages: Sequence[BaseMessage],
+    #     keep_turns: int = 3,
+    #     system_message: BaseMessage = None,
+    # ) -> List[BaseMessage]:
+    #     """
+    #     修正版:按完整对话轮次修剪消息
+    #     每轮对话从Human开始,到下一个Human之前结束
+    #     """
+    #     if not messages:
+    #         return []
+    #     # 分离系统消息(始终保留)
+    #     system_messages = []
+    #     other_messages = []
+    #     for msg in messages:
+    #         if (
+    #             isinstance(msg, SystemMessage)
+    #             or getattr(msg, "type", None) == "system"
+    #             or getattr(msg, "role", None) == "system"
+    #             or msg.__class__.__name__ == "SystemMessage"
+    #         ):
+    #             system_messages.append(msg)
+    #         else:
+    #             other_messages.append(msg)
+
+    #     if len(other_messages) <= 1:
+    #         return system_messages + other_messages
+
+    #     # 找出所有Human消息的位置
+    #     human_indices = []
+    #     for i, msg in enumerate(other_messages):
+    #         if (
+    #             isinstance(msg, HumanMessage)
+    #             or getattr(msg, "type", None) == "human"
+    #             or getattr(msg, "role", None) == "user"
+    #         ):
+    #             human_indices.append(i)
+
+    #     # 如果Human消息不足keep_turns轮,返回所有
+    #     if not human_indices or len(human_indices) <= keep_turns:
+    #         return system_messages + other_messages
+
+    #     # 计算起始索引
+    #     start_idx = human_indices[-keep_turns]
+
+    #     # 获取要保留的消息
+    #     preserved_messages = other_messages[start_idx:]
+
+    #     # 4. 返回从该索引开始的所有消息
+    #     result = system_messages + preserved_messages
+    #     # print(f"修剪后消息数: {len(result)}")
+
+    #     return result
+
+    # @before_model
+    # def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+    #     """Keep only the last few messages to fit context window."""
+    #     messages = state["messages"]
+
+    #     if len(messages) <= 3:
+    #         return None  # No changes needed
+
+    #     # 保留最后4轮对话
+    #     trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
+
+    #     return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
 
     @before_model
-    def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-        """Keep only the last few messages to fit context window."""
-        messages = state["messages"]
-
-        if len(messages) <= 3:
-            return None  # No changes needed
-
-        # 保留最后4轮对话
-        trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
-
-        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
+    def trim_messages_middleware(
+        state: AgentState, runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """使用官方trim_messages函数修剪消息"""
+        messages = state.get("messages", [])
+        print(f"trim_messages_middleware[DEBUG]原始消息数: {len(messages)}")
+        # if len(messages) <= 3:
+        #     return None  # 不需要修剪
+
+        trimmed_messages = trim_messages(
+            messages,
+            max_tokens=1000,
+            strategy="last",  # 保留最近的对话
+            token_counter=count_tokens_approximately,  # token计数器
+            start_on="human",  # 从human消息开始计算轮次
+            include_system=True,  # 包含系统消息
+        )
+
+        # 添加调试信息
+        original_count = len(messages)
+        trimmed_count = len(trimmed_messages)
+        print(f"trim_messages_middleware[DEBUG]修剪后消息数: {trimmed_count}")
+
+        if trimmed_count < original_count:
+            print(f"[INFO]消息修剪: {original_count} -> {trimmed_count} 条消息")
+
+        return {"messages": trimmed_messages}
 
     # 使用SQLiteSaver(按天分割)
     checkpointer = get_sqlite_checkpointer()
@@ -277,7 +326,7 @@ def create_langchain_agent(
         tools,
         checkpointer=checkpointer,
         system_prompt=system_prompt,
-        middleware=[trim_messages],
+        middleware=[trim_messages_middleware],
     )
 
     return agent

+ 1 - 1
core/agent_manager.py

@@ -58,7 +58,7 @@ class AgentManager:
         config_key = self._get_agent_config_key(
             thread_id, clean_username, clean_backend, clean_token
         )
-
+        print(f"config_key: {config_key}")
         # 检查本地配置缓存
         current_time = time.time()
         if config_key in self._local_agent_cache:

+ 1 - 1
core/chat_service.py

@@ -80,7 +80,7 @@ class ChatService:
 
         # 准备输入
         inputs = {"messages": [HumanMessage(content=message)]}
-        config = {"configurable": {"thread_id": thread_id}}
+        config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 5}
 
         chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
 

+ 1 - 0
core/document_processor/document_service.py

@@ -138,6 +138,7 @@ class DocumentProcessingService:
                 }
             )
         print("result:", result)
+        chat_logger.info(f"message_create_bill result: {result}")
         return result
 
     async def _detect_document_type(self, ocr_text: str) -> Optional[str]:

+ 56 - 18
utils/logger.py

@@ -2,6 +2,7 @@ import logging
 import json
 from datetime import datetime
 import os
+import sys
 from typing import Dict, Any
 from pathlib import Path
 
@@ -13,30 +14,46 @@ LOG_DIR.mkdir(exist_ok=True)
 DETAILED_LOG_DIR = LOG_DIR / "detailed"
 DETAILED_LOG_DIR.mkdir(exist_ok=True)
 
+# 全局变量
+current_log_date = datetime.now().strftime("%Y%m%d")
+chat_logger = None  # 初始化为None,在setup_logging中初始化
 
-# 配置日志
-def setup_logging():
+
+def setup_logging(force_reconfigure=False):
     """配置日志系统"""
-    # 主日志记录器
+    global chat_logger
+
+    # 创建或重新获取logger
     logger = logging.getLogger("chat_logger")
     logger.setLevel(logging.INFO)
 
-    # 避免重复添加handler
-    if not logger.handlers:
-        # 文件处理器 - 按天分割
-        log_file = LOG_DIR / f"chat_{datetime.now().strftime('%Y%m%d')}.log"
+    # 文件处理器 - 按天分割
+    today_str = datetime.now().strftime("%Y%m%d")
+    log_file = LOG_DIR / f"chat_{today_str}.log"
+
+    # 检查是否需要重新配置(强制重新配置或没有handlers)
+    needs_reconfigure = force_reconfigure or not logger.handlers
+
+    if needs_reconfigure:
+        # 清理旧的handler
+        for handler in logger.handlers[:]:
+            handler.close()
+            logger.removeHandler(handler)
+
+        # 创建新的文件处理器
         file_handler = logging.FileHandler(log_file, encoding="utf-8")
         file_handler.setLevel(logging.INFO)
 
         # 控制台处理器
         console_handler = logging.StreamHandler()
         console_handler.setLevel(logging.INFO)
-        # 确保控制台输出使用UTF-8编码
-        import sys
-        if hasattr(sys.stdout, 'reconfigure'):
-            sys.stdout.reconfigure(encoding='utf-8')
-        if hasattr(sys.stderr, 'reconfigure'):
-            sys.stderr.reconfigure(encoding='utf-8')
+
+        # 确保控制台编码
+        try:
+            if hasattr(sys.stdout, "reconfigure") and sys.stdout.encoding != "utf-8":
+                sys.stdout.reconfigure(encoding="utf-8")
+        except:
+            pass
 
         # 格式化
         formatter = logging.Formatter(
@@ -49,6 +66,7 @@ def setup_logging():
         logger.addHandler(file_handler)
         logger.addHandler(console_handler)
 
+    chat_logger = logger
     return logger
 
 
@@ -57,10 +75,7 @@ chat_logger = setup_logging()
 
 
 def get_detailed_log_path(user_id: str, timestamp: datetime = None) -> Path:
-    """
-    获取详细日志文件的路径
-    格式: chat_logs/detailed/YYYYMMDD/user_id_YYYYMMDD_HHMMSS.json
-    """
+    """获取详细日志文件的路径"""
     if timestamp is None:
         timestamp = datetime.now()
 
@@ -71,18 +86,41 @@ def get_detailed_log_path(user_id: str, timestamp: datetime = None) -> Path:
 
     # 生成文件名
     filename = f"{user_id}_{timestamp.strftime('%Y%m%d_%H%M%S')}.json"
-
     return date_dir / filename
 
 
+def _reconfigure_logging():
+    """重新配置日志处理器,用于处理跨日情况"""
+    global chat_logger, current_log_date
+
+    today = datetime.now().strftime("%Y%m%d")
+
+    if today != current_log_date:
+        # 记录日期变化信息到当前日志文件
+        if chat_logger:
+            chat_logger.info(
+                f"检测到日期变化: {current_log_date} -> {today}, 重新配置日志处理器"
+            )
+        current_log_date = today
+
+        # 强制重新配置日志处理器
+        setup_logging(force_reconfigure=True)
+
+
 def log_chat_entry(user_id: str, user_message: str, agent_response: Dict[str, Any]):
     """记录完整的对话日志"""
     try:
+        # 每次调用都检查日期是否变化
+        _reconfigure_logging()
+
         timestamp = datetime.now()
+
+        # 构建日志条目
         log_entry = {
             "timestamp": timestamp.isoformat(),
             "user_id": user_id,
             "user_message": user_message,
+            "thread_id": agent_response.get("thread_id", ""),
             "agent_response": {
                 "final_answer": agent_response.get("final_answer", ""),
                 "all_ai_messages_count": len(agent_response.get("all_ai_messages", [])),