|
|
@@ -5,7 +5,12 @@ from pathlib import Path
|
|
|
|
|
|
from langchain.agents import create_agent, AgentState
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
-from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
|
|
+from langchain_core.messages import (
|
|
|
+ SystemMessage,
|
|
|
+ HumanMessage,
|
|
|
+ BaseMessage,
|
|
|
+ trim_messages,
|
|
|
+)
|
|
|
from tools.tool_factory import get_all_tools
|
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
from langchain.agents.middleware import before_model
|
|
|
@@ -15,6 +20,7 @@ from langchain.messages import RemoveMessage
|
|
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
|
|
import sqlite3
|
|
|
from config.settings import settings
|
|
|
+from langchain_core.messages.utils import count_tokens_approximately
|
|
|
|
|
|
dotenv.load_dotenv()
|
|
|
|
|
|
@@ -30,54 +36,68 @@ def create_system_prompt(
|
|
|
if settings.KNOWLEDGE_BASE_ENABLED:
|
|
|
# 知识库启用时的提示词
|
|
|
system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
|
|
|
-职责:ERP数据查询和问题解答,按用户语言回答。
|
|
|
-工作流:
|
|
|
-1. 分析问题意图,提取模块关键词
|
|
|
-2. 如果是数据查询类问题,直接调用相关工具查询数据
|
|
|
-3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
|
|
|
-4. 关键词要精准,避免无意义词
|
|
|
-回答规则:
|
|
|
-- 知识库找不到时提示"正在学习该问题"
|
|
|
-- {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
|
|
|
-- 保护隐私,专业准确,精炼简要
|
|
|
-{"- 后端地址: " + backend_url if backend_url else ""}
|
|
|
-{"- API用户的认证令牌: " + token if token else ""}
|
|
|
-时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
|
|
-数据查询结果尽量以 Markdown 表格格式输出,格式如下:
|
|
|
-| 列名1 | 列名2 | 列名3 |
|
|
|
-| :--- | :--- | :--- |
|
|
|
-| 数据1 | 数据2 | 数据3 |
|
|
|
-| 数据4 | 数据5 | 数据6 |
|
|
|
-"""
|
|
|
+ 职责:ERP数据查询和问题解答,按用户语言回答。
|
|
|
+
|
|
|
+ **核心安全指令 (必遵)**:
|
|
|
+ 1. **当前凭据 (每次工具调用必须使用)**:
|
|
|
+ - 后端地址: {backend_url if backend_url else '无'}
|
|
|
+ - API令牌: {token if token else '无'}
|
|
|
+ 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
|
|
|
+ 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
|
|
|
+ 工作流:
|
|
|
+ 1. 分析问题意图,提取模块关键词
|
|
|
+ 2. 如果是数据查询类问题,直接调用相关工具查询数据
|
|
|
+ 3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
|
|
|
+ 4. 关键词要精准,避免无意义词
|
|
|
+ 工具调用规格:
|
|
|
+ - 如果连续3次调用相同工具相同参数,自动停止
|
|
|
+ - 工具返回相同结果但仍在重复调用时,自动停止
|
|
|
+ 回答规则:
|
|
|
+ - 知识库找不到时提示"正在学习该问题"
|
|
|
+ - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
|
|
|
+ - 保护隐私,专业准确,精炼简要
|
|
|
+ 时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
|
|
+ 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
|
|
|
+ | 列名1 | 列名2 | 列名3 |
|
|
|
+ | :--- | :--- | :--- |
|
|
|
+ | 数据1 | 数据2 | 数据3 |
|
|
|
+ | 数据4 | 数据5 | 数据6 |
|
|
|
+ """
|
|
|
else:
|
|
|
- # 知识库禁用时的提示词 - 严格限制回答范围
|
|
|
+ # 知识库禁用时的提示词 - 灵活处理工具返回结果
|
|
|
system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status}
|
|
|
-职责:仅处理ERP数据查询类问题,按用户语言回答。
|
|
|
-严格限制:
|
|
|
-- 知识库功能已禁用,无法回答任何非数据查询类问题
|
|
|
-- 禁止回答:疑问解答、操作流程、功能介绍、知识咨询等非数据查询问题
|
|
|
-- 禁止使用个人知识或经验进行回答
|
|
|
-工作流:
|
|
|
-1. 分析问题意图,判断是否为数据查询类问题
|
|
|
-2. 如果是数据查询类问题,直接调用相关工具查询数据
|
|
|
-3. 如果是非数据查询类问题(包括疑问、流程、操作等),必须明确回复:"知识库正在完善,无法回答该问题"
|
|
|
-回答规则:
|
|
|
-- 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:仅处理ERP数据查询类问题
|
|
|
-- 非数据查询问题必须回复:"知识库正在完善,无法回答该问题"
|
|
|
-- 禁止尝试回答或提供任何建议
|
|
|
-- 禁止解释原因或提供替代方案
|
|
|
-- 严格按工具提供的数据回答数据查询问题,不能编造答案
|
|
|
-- {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询支持"}
|
|
|
-{"- 后端地址: " + backend_url if backend_url else ""}
|
|
|
-{"- API用户的认证令牌: " + token if token else ""}
|
|
|
-当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
|
|
-数据查询结果尽量以 Markdown 表格格式输出,格式如下:
|
|
|
-| 列名1 | 列名2 | 列名3 |
|
|
|
-| :--- | :--- | :--- |
|
|
|
-| 数据1 | 数据2 | 数据3 |
|
|
|
-| 数据4 | 数据5 | 数据6 |
|
|
|
-"""
|
|
|
- print(system_prompt)
|
|
|
+ 职责:处理ERP数据查询类问题,按用户语言回答。
|
|
|
+ **核心安全指令 (必遵)**:
|
|
|
+ 1. **当前凭据 (每次工具调用必须使用)**:
|
|
|
+ - 后端地址: {backend_url if backend_url else '无'}
|
|
|
+ - API令牌: {token if token else '无'}
|
|
|
+ 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
|
|
|
+ 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
|
|
|
+ 工作流:
|
|
|
+ 1. 分析问题意图,判断是否为数据查询类问题
|
|
|
+ 2. 如果是数据查询类问题,直接调用相关工具查询数据
|
|
|
+ 3. 根据工具返回的结果进行回答:
|
|
|
+ - 如果工具返回了具体数据,按数据内容回答
|
|
|
+ - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
|
|
|
+ - 如果工具返回空数据或"未找到数据",如实告知用户
|
|
|
+ 4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
|
|
|
+ 工具调用规格:
|
|
|
+ - 禁止连续调用相同工具相同参数
|
|
|
+ - 工具返回相同结果但仍在重复调用时,自动停止
|
|
|
+ 回答规则:
|
|
|
+ - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
|
|
|
+ - 工具提示没有权限时,明确回复用户没有权限
|
|
|
+ - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
|
|
|
+ - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
|
|
|
+ - 保持专业、准确、简洁的回答风格
|
|
|
+ {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
|
|
|
+ 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
|
|
+ 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
|
|
|
+ | 列名1 | 列名2 | 列名3 |
|
|
|
+ | :--- | :--- | :--- |
|
|
|
+ | 数据1 | 数据2 | 数据3 |
|
|
|
+ | 数据4 | 数据5 | 数据6 |
|
|
|
+ """
|
|
|
return system_prompt
|
|
|
|
|
|
|
|
|
@@ -197,72 +217,101 @@ def create_langchain_agent(
|
|
|
# 获取动态的system_prompt
|
|
|
system_prompt = create_system_prompt(backend_url, token, username)
|
|
|
|
|
|
- def simple_turn_based_trim(
|
|
|
- messages: Sequence[BaseMessage],
|
|
|
- keep_turns: int = 3,
|
|
|
- system_message: BaseMessage = None,
|
|
|
- ) -> List[BaseMessage]:
|
|
|
- """
|
|
|
- 修正版:按完整对话轮次修剪消息
|
|
|
- 每轮对话从Human开始,到下一个Human之前结束
|
|
|
- """
|
|
|
- if not messages:
|
|
|
- return []
|
|
|
- # 分离系统消息(始终保留)
|
|
|
- system_messages = []
|
|
|
- other_messages = []
|
|
|
- for msg in messages:
|
|
|
- if (
|
|
|
- isinstance(msg, SystemMessage)
|
|
|
- or getattr(msg, "type", None) == "system"
|
|
|
- or getattr(msg, "role", None) == "system"
|
|
|
- or msg.__class__.__name__ == "SystemMessage"
|
|
|
- ):
|
|
|
- system_messages.append(msg)
|
|
|
- else:
|
|
|
- other_messages.append(msg)
|
|
|
-
|
|
|
- if len(other_messages) <= 1:
|
|
|
- return system_messages + other_messages
|
|
|
-
|
|
|
- # 找出所有Human消息的位置
|
|
|
- human_indices = []
|
|
|
- for i, msg in enumerate(other_messages):
|
|
|
- if (
|
|
|
- isinstance(msg, HumanMessage)
|
|
|
- or getattr(msg, "type", None) == "human"
|
|
|
- or getattr(msg, "role", None) == "user"
|
|
|
- ):
|
|
|
- human_indices.append(i)
|
|
|
-
|
|
|
- # 如果Human消息不足keep_turns轮,返回所有
|
|
|
- if not human_indices or len(human_indices) <= keep_turns:
|
|
|
- return system_messages + other_messages
|
|
|
-
|
|
|
- # 计算起始索引
|
|
|
- start_idx = human_indices[-keep_turns]
|
|
|
-
|
|
|
- # 获取要保留的消息
|
|
|
- preserved_messages = other_messages[start_idx:]
|
|
|
-
|
|
|
- # 4. 返回从该索引开始的所有消息
|
|
|
- result = system_messages + preserved_messages
|
|
|
- # print(f"修剪后消息数: {len(result)}")
|
|
|
-
|
|
|
- return result
|
|
|
+ # def simple_turn_based_trim(
|
|
|
+ # messages: Sequence[BaseMessage],
|
|
|
+ # keep_turns: int = 3,
|
|
|
+ # system_message: BaseMessage = None,
|
|
|
+ # ) -> List[BaseMessage]:
|
|
|
+ # """
|
|
|
+ # 修正版:按完整对话轮次修剪消息
|
|
|
+ # 每轮对话从Human开始,到下一个Human之前结束
|
|
|
+ # """
|
|
|
+ # if not messages:
|
|
|
+ # return []
|
|
|
+ # # 分离系统消息(始终保留)
|
|
|
+ # system_messages = []
|
|
|
+ # other_messages = []
|
|
|
+ # for msg in messages:
|
|
|
+ # if (
|
|
|
+ # isinstance(msg, SystemMessage)
|
|
|
+ # or getattr(msg, "type", None) == "system"
|
|
|
+ # or getattr(msg, "role", None) == "system"
|
|
|
+ # or msg.__class__.__name__ == "SystemMessage"
|
|
|
+ # ):
|
|
|
+ # system_messages.append(msg)
|
|
|
+ # else:
|
|
|
+ # other_messages.append(msg)
|
|
|
+
|
|
|
+ # if len(other_messages) <= 1:
|
|
|
+ # return system_messages + other_messages
|
|
|
+
|
|
|
+ # # 找出所有Human消息的位置
|
|
|
+ # human_indices = []
|
|
|
+ # for i, msg in enumerate(other_messages):
|
|
|
+ # if (
|
|
|
+ # isinstance(msg, HumanMessage)
|
|
|
+ # or getattr(msg, "type", None) == "human"
|
|
|
+ # or getattr(msg, "role", None) == "user"
|
|
|
+ # ):
|
|
|
+ # human_indices.append(i)
|
|
|
+
|
|
|
+ # # 如果Human消息不足keep_turns轮,返回所有
|
|
|
+ # if not human_indices or len(human_indices) <= keep_turns:
|
|
|
+ # return system_messages + other_messages
|
|
|
+
|
|
|
+ # # 计算起始索引
|
|
|
+ # start_idx = human_indices[-keep_turns]
|
|
|
+
|
|
|
+ # # 获取要保留的消息
|
|
|
+ # preserved_messages = other_messages[start_idx:]
|
|
|
+
|
|
|
+ # # 4. 返回从该索引开始的所有消息
|
|
|
+ # result = system_messages + preserved_messages
|
|
|
+ # # print(f"修剪后消息数: {len(result)}")
|
|
|
+
|
|
|
+ # return result
|
|
|
+
|
|
|
+ # @before_model
|
|
|
+ # def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
|
|
+ # """Keep only the last few messages to fit context window."""
|
|
|
+ # messages = state["messages"]
|
|
|
+
|
|
|
+ # if len(messages) <= 3:
|
|
|
+ # return None # No changes needed
|
|
|
+
|
|
|
+ # # 保留最后4轮对话
|
|
|
+ # trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
|
|
|
+
|
|
|
+ # return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
|
|
|
|
|
|
@before_model
|
|
|
- def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
|
|
- """Keep only the last few messages to fit context window."""
|
|
|
- messages = state["messages"]
|
|
|
-
|
|
|
- if len(messages) <= 3:
|
|
|
- return None # No changes needed
|
|
|
-
|
|
|
- # 保留最后4轮对话
|
|
|
- trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
|
|
|
-
|
|
|
- return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
|
|
|
+ def trim_messages_middleware(
|
|
|
+ state: AgentState, runtime: Runtime
|
|
|
+ ) -> dict[str, Any] | None:
|
|
|
+ """使用官方trim_messages函数修剪消息"""
|
|
|
+ messages = state.get("messages", [])
|
|
|
+ print(f"trim_messages_middleware[DEBUG]原始消息数: {len(messages)}")
|
|
|
+ # if len(messages) <= 3:
|
|
|
+ # return None # 不需要修剪
|
|
|
+
|
|
|
+ trimmed_messages = trim_messages(
|
|
|
+ messages,
|
|
|
+ max_tokens=1000,
|
|
|
+ strategy="last", # 保留最近的对话
|
|
|
+ token_counter=count_tokens_approximately, # token计数器
|
|
|
+ start_on="human", # 从human消息开始计算轮次
|
|
|
+ include_system=True, # 包含系统消息
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加调试信息
|
|
|
+ original_count = len(messages)
|
|
|
+ trimmed_count = len(trimmed_messages)
|
|
|
+ print(f"trim_messages_middleware[DEBUG]修剪后消息数: {trimmed_count}")
|
|
|
+
|
|
|
+ if trimmed_count < original_count:
|
|
|
+ print(f"[INFO]消息修剪: {original_count} -> {trimmed_count} 条消息")
|
|
|
+
|
|
|
+ return {"messages": trimmed_messages}
|
|
|
|
|
|
# 使用SQLiteSaver(按天分割)
|
|
|
checkpointer = get_sqlite_checkpointer()
|
|
|
@@ -277,7 +326,7 @@ def create_langchain_agent(
|
|
|
tools,
|
|
|
checkpointer=checkpointer,
|
|
|
system_prompt=system_prompt,
|
|
|
- middleware=[trim_messages],
|
|
|
+ middleware=[trim_messages_middleware],
|
|
|
)
|
|
|
|
|
|
return agent
|