async_chat_service.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import asyncio
  2. import time
  3. from typing import Dict, Any
  4. from langchain_core.messages import HumanMessage
  5. from utils.logger import chat_logger, log_chat_entry
  6. from core.agent_manager import agent_manager
  7. from core.chat_result_manager import chat_result_manager
  8. class AsyncChatService:
  9. """异步聊天服务 - 支持轮询方式的版本"""
  10. def __init__(self):
  11. self.agent_manager = agent_manager
  12. self._thread_pool = None
  13. self._processing_tasks = {} # 正在处理的任务缓存
  14. def _get_thread_pool(self):
  15. """获取或创建线程池"""
  16. if self._thread_pool is None:
  17. import concurrent.futures
  18. self._thread_pool = concurrent.futures.ThreadPoolExecutor(
  19. max_workers=20,
  20. thread_name_prefix="async_chat_worker",
  21. )
  22. return self._thread_pool
  23. async def submit_chat_task(self, request_data: Dict[str, Any]) -> str:
  24. """提交聊天任务(立即返回任务ID)"""
  25. username = request_data["username"]
  26. # 创建任务记录
  27. task_id = chat_result_manager.create_task(request_data)
  28. chat_logger.info(f"用户{username},已提交聊天任务: {task_id}")
  29. # 异步执行任务
  30. asyncio.create_task(self._process_chat_task(task_id, request_data))
  31. return task_id
  32. async def _process_chat_task(self, task_id: str, request_data: Dict[str, Any]):
  33. """异步处理聊天任务"""
  34. try:
  35. # 更新状态为处理中
  36. chat_result_manager.update_task_status(task_id, "processing")
  37. # 提取请求数据
  38. message = request_data["message"]
  39. thread_id = request_data["thread_id"]
  40. username = request_data["username"]
  41. backend_url = request_data["backend_url"]
  42. token = request_data["token"]
  43. user_id = username
  44. chat_logger.info(f"开始处理任务 - 任务ID={task_id}, 用户={user_id}")
  45. # 异步获取agent实例
  46. agent = await self.agent_manager.get_agent_instance(
  47. thread_id=thread_id,
  48. username=username,
  49. backend_url=backend_url,
  50. token=token,
  51. )
  52. # 在线程池中执行同步的Langchain操作
  53. result = await self._run_agent_in_threadpool(
  54. agent, message, thread_id, user_id
  55. )
  56. if not isinstance(result, dict) or "messages" not in result:
  57. raise ValueError(f"Agent返回格式异常: {type(result)}")
  58. # 处理结果
  59. response_data = self._process_agent_result(result, user_id, request_data)
  60. # 更新任务状态为完成
  61. chat_result_manager.update_task_status(task_id, "completed", response_data)
  62. chat_logger.info(f"任务处理完成 - 任务ID={task_id}")
  63. except Exception as e:
  64. error_msg = f"聊天处理失败: {str(e)}"
  65. chat_logger.error(f"{error_msg} - 任务ID={task_id}")
  66. # 更新任务状态为失败
  67. chat_result_manager.update_task_status(
  68. task_id, "failed", error_message=error_msg
  69. )
  70. async def _run_agent_in_threadpool(
  71. self, agent, message: str, thread_id: str, user_id: str
  72. ):
  73. """在线程池中执行Langchain Agent"""
  74. loop = asyncio.get_event_loop()
  75. thread_pool = self._get_thread_pool()
  76. # 准备输入
  77. inputs = {"messages": [HumanMessage(content=message)]}
  78. config = {"configurable": {"thread_id": thread_id}}
  79. chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
  80. try:
  81. # 在线程池中执行同步操作
  82. result = await loop.run_in_executor(
  83. thread_pool, lambda: agent.invoke(inputs, config)
  84. )
  85. return result
  86. except Exception as e:
  87. chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
  88. raise
  89. def _process_agent_result(
  90. self, result: Dict[str, Any], user_id: str, request_data: Dict
  91. ) -> Dict[str, Any]:
  92. """处理Agent返回结果"""
  93. all_messages = result["messages"]
  94. processed_messages = []
  95. all_ai_messages = []
  96. all_tool_calls = []
  97. final_answer = ""
  98. for i, msg in enumerate(all_messages):
  99. msg_data = {
  100. "index": i,
  101. "type": getattr(msg, "type", "unknown"),
  102. "content": "",
  103. }
  104. # 获取内容
  105. if hasattr(msg, "content"):
  106. content = msg.content
  107. if isinstance(content, str):
  108. msg_data["content"] = content
  109. else:
  110. msg_data["content"] = str(content)
  111. # 获取工具调用
  112. if hasattr(msg, "tool_calls") and msg.tool_calls:
  113. msg_data["tool_calls"] = msg.tool_calls
  114. all_tool_calls.extend(msg.tool_calls)
  115. for tool_call in msg.tool_calls:
  116. tool_name = tool_call.get("name", "unknown")
  117. tool_args = tool_call.get("args", {})
  118. chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
  119. if hasattr(msg, "tool_call_id"):
  120. msg_data["tool_call_id"] = msg.tool_call_id
  121. if hasattr(msg, "name"):
  122. msg_data["name"] = msg.name
  123. processed_messages.append(msg_data)
  124. # 收集AI消息
  125. if msg_data["type"] == "ai":
  126. all_ai_messages.append(msg_data)
  127. final_answer = msg_data["content"]
  128. # 构建响应
  129. response = {
  130. "final_answer": final_answer,
  131. "all_ai_messages": all_ai_messages,
  132. "all_messages": processed_messages,
  133. "tool_calls": all_tool_calls,
  134. "thread_id": request_data["thread_id"],
  135. "user_identifier": user_id,
  136. "backend_config": {
  137. "backend_url": request_data["backend_url"] or "未配置",
  138. "username": request_data["username"],
  139. "has_token": bool(request_data["token"]),
  140. },
  141. "success": True,
  142. }
  143. # 记录日志
  144. log_chat_entry(user_id, request_data["message"], response)
  145. return response
  146. async def get_task_result(self, task_id: str) -> Dict[str, Any]:
  147. """获取任务结果"""
  148. task_info = chat_result_manager.get_task(task_id)
  149. chat_logger.info(f"获取任务结果 - 任务ID={task_id}, 状态={task_info['status']}")
  150. if not task_info:
  151. return {
  152. "success": False,
  153. "error": f"任务不存在: {task_id}",
  154. "task_id": task_id,
  155. }
  156. return {
  157. "task_id": task_id,
  158. "status": task_info["status"],
  159. "response": task_info["response_data"],
  160. "error": task_info["error_message"],
  161. "created_at": task_info["created_at"],
  162. "updated_at": task_info["updated_at"],
  163. "success": task_info["status"] == "completed",
  164. }
  165. async def shutdown(self):
  166. """关闭服务"""
  167. if self._thread_pool:
  168. self._thread_pool.shutdown(wait=False)
  169. self._thread_pool = None
  170. chat_result_manager.close()
  171. chat_logger.info("异步聊天服务已关闭")
# Module-level service instance, created when this module is imported.
async_chat_service = AsyncChatService()