chat_service.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # core/chat_service.py - 修复版本
  2. import asyncio
  3. from typing import Dict, Any, List
  4. from langchain_core.messages import HumanMessage
  5. from utils.logger import chat_logger, log_chat_entry
  6. from core.agent_manager import agent_manager
  7. class ChatService:
  8. """聊天服务 - 支持真正并发的版本"""
  9. def __init__(self):
  10. self.agent_manager = agent_manager
  11. # 创建专用的线程池用于执行同步的Langchain操作
  12. self._thread_pool = None
  13. def _get_thread_pool(self):
  14. """获取或创建线程池"""
  15. if self._thread_pool is None:
  16. import concurrent.futures
  17. # 创建足够大的线程池支持并发
  18. self._thread_pool = concurrent.futures.ThreadPoolExecutor(
  19. max_workers=20, # 根据服务器配置调整
  20. thread_name_prefix="langchain_worker",
  21. )
  22. return self._thread_pool
  23. async def process_chat_request(
  24. self, request_data: Dict[str, Any]
  25. ) -> Dict[str, Any]:
  26. """异步处理聊天请求 - 真正并发版本"""
  27. try:
  28. # 提取请求数据
  29. message = request_data["message"]
  30. thread_id = request_data["thread_id"]
  31. username = request_data["username"]
  32. backend_url = request_data["backend_url"]
  33. token = request_data["token"]
  34. # 生成用户标识符
  35. user_id = self.agent_manager._get_user_identifier(username, token)
  36. chat_logger.info(
  37. f"收到请求 - 用户={user_id} , 线程ID={thread_id}, 消息={message[:100]}"
  38. )
  39. # 异步获取agent实例
  40. agent = await self.agent_manager.get_agent_instance(
  41. thread_id=thread_id,
  42. username=username,
  43. backend_url=backend_url,
  44. token=token,
  45. )
  46. # ✅ 修复:在线程池中执行同步的Langchain操作
  47. result = await self._run_agent_in_threadpool(
  48. agent, message, thread_id, user_id
  49. )
  50. chat_logger.info(f"Agent处理完成 - 用户={user_id}")
  51. if not isinstance(result, dict) or "messages" not in result:
  52. raise ValueError(f"Agent返回格式异常: {type(result)}")
  53. # 处理结果
  54. return self._process_agent_result(result, user_id, request_data)
  55. except Exception as e:
  56. chat_logger.error(f"聊天处理失败: {str(e)}")
  57. raise
  58. async def _run_agent_in_threadpool(
  59. self, agent, message: str, thread_id: str, user_id: str
  60. ):
  61. """在线程池中执行Langchain Agent"""
  62. loop = asyncio.get_event_loop()
  63. thread_pool = self._get_thread_pool()
  64. # 准备输入
  65. inputs = {"messages": [HumanMessage(content=message)]}
  66. config = {"configurable": {"thread_id": thread_id}}
  67. chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
  68. try:
  69. # 在线程池中执行同步操作
  70. result = await loop.run_in_executor(
  71. thread_pool, lambda: agent.invoke(inputs, config)
  72. )
  73. return result
  74. except Exception as e:
  75. chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
  76. raise
  77. def _process_agent_result(
  78. self, result: Dict[str, Any], user_id: str, request_data: Dict
  79. ) -> Dict[str, Any]:
  80. """处理Agent返回结果"""
  81. all_messages = result["messages"]
  82. processed_messages = []
  83. all_ai_messages = []
  84. all_tool_calls = []
  85. final_answer = ""
  86. for i, msg in enumerate(all_messages):
  87. msg_data = {
  88. "index": i,
  89. "type": getattr(msg, "type", "unknown"),
  90. "content": "",
  91. }
  92. # 获取内容
  93. if hasattr(msg, "content"):
  94. content = msg.content
  95. if isinstance(content, str):
  96. msg_data["content"] = content
  97. else:
  98. msg_data["content"] = str(content)
  99. # 获取工具调用
  100. if hasattr(msg, "tool_calls") and msg.tool_calls:
  101. msg_data["tool_calls"] = msg.tool_calls
  102. all_tool_calls.extend(msg.tool_calls)
  103. for tool_call in msg.tool_calls:
  104. tool_name = tool_call.get("name", "unknown")
  105. tool_args = tool_call.get("args", {})
  106. chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
  107. if hasattr(msg, "tool_call_id"):
  108. msg_data["tool_call_id"] = msg.tool_call_id
  109. if hasattr(msg, "name"):
  110. msg_data["name"] = msg.name
  111. processed_messages.append(msg_data)
  112. # 收集AI消息
  113. if msg_data["type"] == "ai":
  114. all_ai_messages.append(msg_data)
  115. final_answer = msg_data["content"]
  116. # 构建响应
  117. response = {
  118. "final_answer": final_answer,
  119. "all_ai_messages": all_ai_messages,
  120. "all_messages": processed_messages,
  121. "tool_calls": all_tool_calls,
  122. "thread_id": request_data["thread_id"],
  123. "user_identifier": user_id,
  124. "backend_config": {
  125. "backend_url": request_data["backend_url"] or "未配置",
  126. "username": request_data["username"],
  127. "has_token": bool(request_data["token"]),
  128. },
  129. "success": True,
  130. }
  131. # 记录日志
  132. log_chat_entry(user_id, request_data["message"], response)
  133. chat_logger.info(f"请求处理完成 - 用户={user_id}")
  134. return response
  135. async def shutdown(self):
  136. """关闭线程池"""
  137. if self._thread_pool:
  138. self._thread_pool.shutdown(wait=False)
  139. self._thread_pool = None
  140. chat_logger.info("聊天服务线程池已关闭")
# Module-level shared ChatService instance, created once at import time.
chat_service = ChatService()