chat_service.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # core/chat_service.py - 修复版本
  2. import asyncio
  3. from typing import Dict, Any, List
  4. from langchain_core.messages import HumanMessage
  5. from utils.logger import chat_logger, log_chat_entry
  6. from core.agent_manager import agent_manager
  7. class ChatService:
  8. """聊天服务 - 支持真正并发的版本"""
  9. def __init__(self):
  10. self.agent_manager = agent_manager
  11. # 创建专用的线程池用于执行同步的Langchain操作
  12. self._thread_pool = None
  13. def _get_thread_pool(self):
  14. """获取或创建线程池"""
  15. if self._thread_pool is None:
  16. import concurrent.futures
  17. # 创建足够大的线程池支持并发
  18. self._thread_pool = concurrent.futures.ThreadPoolExecutor(
  19. max_workers=20, # 根据服务器配置调整
  20. thread_name_prefix="langchain_worker",
  21. )
  22. return self._thread_pool
  23. async def process_chat_request(
  24. self, request_data: Dict[str, Any]
  25. ) -> Dict[str, Any]:
  26. """异步处理聊天请求 - 真正并发版本"""
  27. try:
  28. # 提取请求数据
  29. message = request_data["message"]
  30. thread_id = request_data["thread_id"]
  31. username = request_data["username"]
  32. backend_url = request_data["backend_url"]
  33. token = request_data["token"]
  34. # 生成用户标识符
  35. # user_id = self.agent_manager._get_user_identifier(username, token)
  36. user_id = username
  37. chat_logger.info(
  38. f"收到请求 - 用户={user_id} , 线程ID={thread_id}, 消息={message[:100]}"
  39. )
  40. # 异步获取agent实例
  41. agent = await self.agent_manager.get_agent_instance(
  42. thread_id=thread_id,
  43. username=username,
  44. backend_url=backend_url,
  45. token=token,
  46. )
  47. # ✅ 修复:在线程池中执行同步的Langchain操作
  48. result = await self._run_agent_in_threadpool(
  49. agent, message, thread_id, user_id
  50. )
  51. chat_logger.info(f"Agent处理完成 - 用户={user_id}")
  52. if not isinstance(result, dict) or "messages" not in result:
  53. raise ValueError(f"Agent返回格式异常: {type(result)}")
  54. # 处理结果
  55. return self._process_agent_result(result, user_id, request_data)
  56. except Exception as e:
  57. chat_logger.error(f"聊天处理失败: {str(e)}")
  58. raise
  59. async def _run_agent_in_threadpool(
  60. self, agent, message: str, thread_id: str, user_id: str
  61. ):
  62. """在线程池中执行Langchain Agent"""
  63. loop = asyncio.get_event_loop()
  64. thread_pool = self._get_thread_pool()
  65. # 准备输入
  66. inputs = {"messages": [HumanMessage(content=message)]}
  67. config = {"configurable": {"thread_id": thread_id}}
  68. chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
  69. try:
  70. # 在线程池中执行同步操作
  71. result = await loop.run_in_executor(
  72. thread_pool, lambda: agent.invoke(inputs, config)
  73. )
  74. return result
  75. except Exception as e:
  76. chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
  77. raise
  78. def _process_agent_result(
  79. self, result: Dict[str, Any], user_id: str, request_data: Dict
  80. ) -> Dict[str, Any]:
  81. """处理Agent返回结果"""
  82. all_messages = result["messages"]
  83. processed_messages = []
  84. all_ai_messages = []
  85. all_tool_calls = []
  86. final_answer = ""
  87. for i, msg in enumerate(all_messages):
  88. msg_data = {
  89. "index": i,
  90. "type": getattr(msg, "type", "unknown"),
  91. "content": "",
  92. }
  93. # 获取内容
  94. if hasattr(msg, "content"):
  95. content = msg.content
  96. if isinstance(content, str):
  97. msg_data["content"] = content
  98. else:
  99. msg_data["content"] = str(content)
  100. # 获取工具调用
  101. if hasattr(msg, "tool_calls") and msg.tool_calls:
  102. msg_data["tool_calls"] = msg.tool_calls
  103. all_tool_calls.extend(msg.tool_calls)
  104. for tool_call in msg.tool_calls:
  105. tool_name = tool_call.get("name", "unknown")
  106. tool_args = tool_call.get("args", {})
  107. chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
  108. if hasattr(msg, "tool_call_id"):
  109. msg_data["tool_call_id"] = msg.tool_call_id
  110. if hasattr(msg, "name"):
  111. msg_data["name"] = msg.name
  112. processed_messages.append(msg_data)
  113. # 收集AI消息
  114. if msg_data["type"] == "ai":
  115. all_ai_messages.append(msg_data)
  116. final_answer = msg_data["content"]
  117. # 构建响应
  118. response = {
  119. "final_answer": final_answer,
  120. "all_ai_messages": all_ai_messages,
  121. "all_messages": processed_messages,
  122. "tool_calls": all_tool_calls,
  123. "thread_id": request_data["thread_id"],
  124. "user_identifier": user_id,
  125. "backend_config": {
  126. "backend_url": request_data["backend_url"] or "未配置",
  127. "username": request_data["username"],
  128. "has_token": bool(request_data["token"]),
  129. },
  130. "success": True,
  131. }
  132. # 记录日志
  133. log_chat_entry(user_id, request_data["message"], response)
  134. chat_logger.info(f"请求处理完成 - 用户={user_id}")
  135. return response
  136. async def shutdown(self):
  137. """关闭线程池"""
  138. if self._thread_pool:
  139. self._thread_pool.shutdown(wait=False)
  140. self._thread_pool = None
  141. chat_logger.info("聊天服务线程池已关闭")
  142. # 全局实例
  143. chat_service = ChatService()