Kaynağa Gözat

改用轮询方式获取结果

longjoedyy 1 hafta önce
ebeveyn
işleme
b723ea9d24

+ 32 - 1
api/models.py

@@ -1,5 +1,26 @@
 from pydantic import BaseModel
-from typing import Optional, List, Dict
+from typing import Optional, List, Dict, Any
+
+
+class AsyncChatResponse(BaseModel):
+    """异步聊天响应"""
+
+    success: bool
+    task_id: str
+    message: str
+    polling_endpoint: str
+
+
+class ChatResultResponse(BaseModel):
+    """聊天结果响应"""
+
+    success: bool
+    task_id: str
+    status: str  # pending, processing, completed, failed
+    response: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    created_at: Optional[str] = None
+    updated_at: Optional[str] = None
 
 
 class ChatRequest(BaseModel):
@@ -49,6 +70,7 @@ class MessageCreateBill(BaseModel):
 
 class ImageVectorItem(BaseModel):
     """图片向量项"""
+
     image_id: str  # 图片ID
     vector: List[float]  # 图片特征向量
     image_name: Optional[str] = None  # 图片名称,可选
@@ -57,12 +79,14 @@ class ImageVectorItem(BaseModel):
 
 class ImageVectorRequest(BaseModel):
     """计算图片特征向量请求"""
+
     image: str  # Base64编码的图片数据
     image_id: Optional[str] = None  # 图片ID,可选
 
 
 class ImageVectorResponse(BaseModel):
     """计算图片特征向量响应"""
+
     success: bool
     image_id: Optional[str] = None
     vector: Optional[List[float]] = None
@@ -71,11 +95,13 @@ class ImageVectorResponse(BaseModel):
 
 class BuildIndexRequest(BaseModel):
     """构建索引请求"""
+
     image_vectors: List[ImageVectorItem]  # 图片向量列表
 
 
 class BuildIndexResponse(BaseModel):
     """构建索引响应"""
+
     success: bool
     indexed_count: int  # 索引的图片数量
     error: Optional[str] = None
@@ -83,6 +109,7 @@ class BuildIndexResponse(BaseModel):
 
 class SearchResultItem(BaseModel):
     """搜索结果项"""
+
     image_id: str  # 图片ID
     similarity: float  # 相似度
     image_name: Optional[str] = None  # 图片名称
@@ -91,6 +118,7 @@ class SearchResultItem(BaseModel):
 
 class SearchRequest(BaseModel):
     """搜索请求(支持以图搜图和以文搜图)"""
+
     image: Optional[str] = None  # Base64编码的图片数据,可选
     text: Optional[str] = None  # 文字描述,可选
     top_k: int = 10  # 返回结果数量
@@ -98,6 +126,7 @@ class SearchRequest(BaseModel):
 
 class SearchResponse(BaseModel):
     """搜索响应"""
+
     success: bool
     results: List[SearchResultItem]  # 搜索结果列表
     total_count: int  # 总结果数量
@@ -107,6 +136,7 @@ class SearchResponse(BaseModel):
 
 class SearchHistoryRequest(BaseModel):
     """记录搜索历史请求"""
+
     empid: int  # 员工ID
     search_type: int  # 搜索类型:1-以图搜图,2-以文搜图,3-图文混合搜索
     search_content: str  # 搜索内容
@@ -116,6 +146,7 @@ class SearchHistoryRequest(BaseModel):
 
 class SearchHistoryResponse(BaseModel):
     """记录搜索历史响应"""
+
     success: bool
     history_id: Optional[int] = None
     error: Optional[str] = None

+ 88 - 50
api/routes.py

@@ -5,9 +5,17 @@ from fastapi import APIRouter, HTTPException
 
 from utils.device_id import get_device_id
 from .models import (
-    ChatRequest, ChatResponse, OCRRequest, MessageCreateBill,
-    ImageVectorRequest, ImageVectorResponse, BuildIndexRequest, BuildIndexResponse,
-    SearchRequest, SearchResponse, SearchResultItem
+    ChatRequest,
+    ChatResponse,
+    OCRRequest,
+    MessageCreateBill,
+    ImageVectorRequest,
+    ImageVectorResponse,
+    BuildIndexRequest,
+    BuildIndexResponse,
+    SearchRequest,
+    SearchResponse,
+    SearchResultItem,
 )
 from core.chat_service import chat_service
 from core.agent_manager import agent_manager
@@ -18,6 +26,9 @@ import time
 from utils.registration_manager import registration_manager
 from core.ocr_service import PaddleOCRService
 from core.document_processor.document_service import DocumentProcessingService
+from core.async_chat_service import async_chat_service
+from core.chat_result_manager import chat_result_manager
+from typing import Dict, Any
 
 # 初始化服务
 ocr_service = PaddleOCRService(
@@ -30,6 +41,44 @@ doc_service = DocumentProcessingService(ocr_service=ocr_service)
 router = APIRouter()
 
 
+@router.post("/chat/async", response_model=Dict[str, Any])
+async def async_chat_endpoint(request: ChatRequest):
+    """异步聊天接口 - 立即返回任务ID"""
+    try:
+        task_id = await async_chat_service.submit_chat_task(request.model_dump())
+        return {
+            "success": True,
+            "task_id": task_id,
+            "message": "任务已提交,请使用任务ID轮询结果",
+            "polling_endpoint": f"/chat/result/{task_id}",
+        }
+    except Exception as e:
+        chat_logger.error(f"异步聊天任务提交失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"任务提交失败: {str(e)}")
+
+
+@router.get("/chat/result/{task_id}", response_model=Dict[str, Any])
+async def get_chat_result(task_id: str):
+    """获取聊天结果"""
+    try:
+        result = await async_chat_service.get_task_result(task_id)
+        return result
+    except Exception as e:
+        chat_logger.error(f"获取聊天结果失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"获取结果失败: {str(e)}")
+
+
+@router.get("/chat/tasks/cleanup")
+async def cleanup_old_tasks():
+    """清理旧任务数据(管理员接口)"""
+    try:
+        chat_result_manager.cleanup_old_tasks(max_days=7)
+        return {"success": True, "message": "旧任务数据清理完成"}
+    except Exception as e:
+        chat_logger.error(f"清理旧任务失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"清理失败: {str(e)}")
+
+
 @router.post("/chat", response_model=ChatResponse)
 async def chat_endpoint(request: ChatRequest):
     """聊天接口"""
@@ -329,14 +378,11 @@ async def batch_calculate_vectors_endpoint(requests: List[ImageVectorRequest]):
         # 构建请求数据
         image_items = []
         for req in requests:
-            image_items.append({
-                "image": req.image,
-                "image_id": req.image_id
-            })
-        
+            image_items.append({"image": req.image, "image_id": req.image_id})
+
         # 调用服务
         results = await image_search_service.batch_calculate_vectors(image_items)
-        
+
         # 构建响应
         responses = []
         for result in results:
@@ -344,12 +390,12 @@ async def batch_calculate_vectors_endpoint(requests: List[ImageVectorRequest]):
                 success=result.get("success", False),
                 image_id=result.get("image_id"),
                 vector=result.get("vector"),
-                error=result.get("error")
+                error=result.get("error"),
             )
             responses.append(response)
-        
+
         return responses
-        
+
     except Exception as e:
         chat_logger.error(f"批量计算图片特征向量失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@@ -362,24 +408,23 @@ async def build_index_endpoint(request: BuildIndexRequest):
         # 构建请求数据
         image_vectors = []
         for item in request.image_vectors:
-            image_vectors.append({
-                "image_id": item.image_id,
-                "vector": item.vector,
-                "image_name": item.image_name,
-                "image_path": item.image_path
-            })
-        
+            image_vectors.append(
+                {
+                    "image_id": item.image_id,
+                    "vector": item.vector,
+                    "image_name": item.image_name,
+                    "image_path": item.image_path,
+                }
+            )
+
         # 调用服务
         indexed_count = await image_search_service.build_index(image_vectors)
-        
+
         # 构建响应
-        response = BuildIndexResponse(
-            success=True,
-            indexed_count=indexed_count
-        )
-        
+        response = BuildIndexResponse(success=True, indexed_count=indexed_count)
+
         return response
-        
+
     except Exception as e:
         chat_logger.error(f"构建索引失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@@ -390,8 +435,9 @@ async def search_endpoint(request: SearchRequest):
     """搜索相似图片(支持以图搜图和以文搜图)"""
     try:
         import time
+
         start_time = time.time()
-        
+
         # 处理图片数据
         image_bytes = None
         if request.image:
@@ -400,23 +446,21 @@ async def search_endpoint(request: SearchRequest):
                     base64_str = request.image.split(",", 1)[1]
                 else:
                     base64_str = request.image
-                
+
                 image_bytes = base64.b64decode(base64_str)
-                
+
             except Exception as e:
                 chat_logger.error(f"图片解码失败: {e}")
                 raise HTTPException(400, f"图片格式错误: {str(e)}")
-        
+
         # 调用服务
         results = await image_search_service.search(
-            image_bytes=image_bytes,
-            text=request.text,
-            top_k=request.top_k
+            image_bytes=image_bytes, text=request.text, top_k=request.top_k
         )
-        
+
         # 计算处理时间
         processing_time = time.time() - start_time
-        
+
         # 构建响应
         search_results = []
         for result in results:
@@ -424,19 +468,19 @@ async def search_endpoint(request: SearchRequest):
                 image_id=result.get("image_id"),
                 similarity=result.get("similarity"),
                 image_name=result.get("image_name"),
-                image_path=result.get("image_path")
+                image_path=result.get("image_path"),
             )
             search_results.append(item)
-        
+
         response = SearchResponse(
             success=True,
             results=search_results,
             total_count=len(search_results),
-            processing_time=round(processing_time, 4)
+            processing_time=round(processing_time, 4),
         )
-        
+
         return response
-        
+
     except HTTPException:
         raise
     except Exception as e:
@@ -449,11 +493,8 @@ async def get_index_status_endpoint():
     """获取索引状态"""
     try:
         status = await image_search_service.get_index_status()
-        return {
-            "success": True,
-            "status": status
-        }
-        
+        return {"success": True, "status": status}
+
     except Exception as e:
         chat_logger.error(f"获取索引状态失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@@ -464,11 +505,8 @@ async def clear_index_endpoint():
     """清空索引"""
     try:
         await image_search_service.clear_index()
-        return {
-            "success": True,
-            "message": "索引已清空"
-        }
-        
+        return {"success": True, "message": "索引已清空"}
+
     except Exception as e:
         chat_logger.error(f"清空索引失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")

+ 5 - 5
config/settings.py

@@ -10,7 +10,7 @@ class Settings:
         # 1. 检测环境
         self.is_development = Path(".env.development").exists()
         self.env = "development" if self.is_development else "production"
-        print(f"🎯 当前环境: {self.env}")
+        print(f"当前环境: {self.env}")
 
         # 2. 加载配置文件
         self._load_env_file()
@@ -30,7 +30,7 @@ class Settings:
         )
 
         # 其他配置
-        self.LLM_MODEL = "deepseek-chat"
+        self.LLM_MODEL = "qwen-flash"  # os.getenv("LLM_MODEL", "deepseek-chat")
         self.LLM_TEMPERATURE = 0.7
         self.LLM_MAX_TOKENS = 2048
 
@@ -42,9 +42,9 @@ class Settings:
 
         if Path(env_file).exists():
             load_dotenv(env_file, override=True)
-            print(f"📁 加载配置文件: {env_file}")
+            print(f"加载配置文件: {env_file}")
         else:
-            print(f"⚠️  配置文件 {env_file} 不存在")
+            print(f"配置文件 {env_file} 不存在")
 
     def _get_api_key(self) -> str:
         """根据环境获取API Key"""
@@ -52,7 +52,7 @@ class Settings:
             # 开发环境:使用明文
             key = os.getenv("DEEPSEEK_API_KEY")
             if not key:
-                print("\n⚠️  开发环境提示:")
+                print("\n开发环境提示:")
                 print("   请在 .env.development 中添加:")
                 print("   DEEPSEEK_API_KEY=sk-your-key-here")
                 print("=" * 40)

+ 1 - 1
config/tool_config.json

@@ -31,6 +31,6 @@
         "输出格式要求": [
             "重复信息要总结归纳,精简显示"
         ],
-        "使用示例": "用户输入:'查看2023年1月1日至2023年12月31日的销售金额' -> 系统调用此工具获取2023年1月至12月的销售金额"
+        "使用示例": "用户输入:'查看2023年1月1日至2023年12月31日的销售金额' -> 系统调用此工具获取2023年1月至12月的销售金额;'2024年前5热销产品是哪些?' -> 系统调用此工具获取2024年产品销售额;'2025年前10销售额最高的客户是?' -> 系统调用此工具获取2025年客户销售额"
     }
 }

+ 7 - 5
core/agent.py

@@ -85,7 +85,7 @@ def get_sqlite_checkpointer():
         return checkpointer
 
     except Exception as e:
-        print(f"❌❌ 创建 SQLite 检查器失败: {e}")
+        print(f"[ERROR]创建 SQLite 检查器失败: {e}")
         import traceback
 
         traceback.print_exc()
@@ -93,7 +93,7 @@ def get_sqlite_checkpointer():
         # 回退到内存保存器
         from langgraph.checkpoint.memory import InMemorySaver
 
-        print("⚠️ 使用 InMemorySaver 作为回退")
+        print("[WARN]使用 InMemorySaver 作为回退")
         return InMemorySaver()
 
 
@@ -126,14 +126,16 @@ def cleanup_old_checkpoints(max_days=7):
                     if days_diff > max_days:
                         file_path = os.path.join(base_dir, filename)
                         os.remove(file_path)
-                        print(f"🧹🧹 清理旧检查点文件: {filename} (超过 {max_days} 天)")
+                        print(
+                            f"[CLEAN]清理旧检查点文件: {filename} (超过 {max_days} 天)"
+                        )
 
                 except (ValueError, IndexError):
                     # 文件名不符合预期,跳过
                     continue
 
     except Exception as e:
-        print(f"⚠️ 清理旧检查点失败: {e}")
+        print(f"[WARN]清理旧检查点失败: {e}")
 
 
 # 创建agent
@@ -153,7 +155,7 @@ def create_langchain_agent(
 
     tools = get_all_tools()
     # 添加调试信息
-    print(f"🔧🔧🔧🔧 Agent 创建调试信息:")
+    print(f"[DEBUG]Agent 创建调试信息:")
     print(f"  - 用户: {username}")
     print(f"  - Thread ID: {thread_id}")
     print(f"  - 后端地址: {backend_url}")

+ 2 - 2
core/agent_manager.py

@@ -15,13 +15,13 @@ class AgentManager:
     async def initialize(self):
         """异步初始化管理器"""
         self._is_shutdown = False
-        chat_logger.info("🔧🔧🔧🔧 Agent管理器初始化完成")
+        chat_logger.info("Agent管理器初始化完成")
 
     async def shutdown(self):
         """异步关闭管理器"""
         self._is_shutdown = True
         self._local_agent_cache.clear()
-        chat_logger.info("🧹🧹🧹🧹 Agent管理器已关闭")
+        chat_logger.info("Agent管理器已关闭")
 
     def _get_agent_config_key(
         self, thread_id: str, username: str, backend_url: str, token: str

+ 213 - 0
core/async_chat_service.py

@@ -0,0 +1,213 @@
+import asyncio
+import time
+from typing import Dict, Any
+from langchain_core.messages import HumanMessage
+from utils.logger import chat_logger, log_chat_entry
+from core.agent_manager import agent_manager
+from core.chat_result_manager import chat_result_manager
+
+
+class AsyncChatService:
+    """异步聊天服务 - 支持轮询方式的版本"""
+
+    def __init__(self):
+        self.agent_manager = agent_manager
+        self._thread_pool = None
+        self._processing_tasks = {}  # 正在处理的任务缓存
+
+    def _get_thread_pool(self):
+        """获取或创建线程池"""
+        if self._thread_pool is None:
+            import concurrent.futures
+
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor(
+                max_workers=20,
+                thread_name_prefix="async_chat_worker",
+            )
+        return self._thread_pool
+
+    async def submit_chat_task(self, request_data: Dict[str, Any]) -> str:
+        """提交聊天任务(立即返回任务ID)"""
+        username = request_data["username"]
+        # 创建任务记录
+        task_id = chat_result_manager.create_task(request_data)
+        chat_logger.info(f"用户{username},已提交聊天任务: {task_id}")
+
+        # 异步执行任务
+        asyncio.create_task(self._process_chat_task(task_id, request_data))
+
+        return task_id
+
+    async def _process_chat_task(self, task_id: str, request_data: Dict[str, Any]):
+        """异步处理聊天任务"""
+        try:
+            # 更新状态为处理中
+            chat_result_manager.update_task_status(task_id, "processing")
+
+            # 提取请求数据
+            message = request_data["message"]
+            thread_id = request_data["thread_id"]
+            username = request_data["username"]
+            backend_url = request_data["backend_url"]
+            token = request_data["token"]
+            user_id = username
+
+            chat_logger.info(f"开始处理任务 - 任务ID={task_id}, 用户={user_id}")
+
+            # 异步获取agent实例
+            agent = await self.agent_manager.get_agent_instance(
+                thread_id=thread_id,
+                username=username,
+                backend_url=backend_url,
+                token=token,
+            )
+
+            # 在线程池中执行同步的Langchain操作
+            result = await self._run_agent_in_threadpool(
+                agent, message, thread_id, user_id
+            )
+
+            if not isinstance(result, dict) or "messages" not in result:
+                raise ValueError(f"Agent返回格式异常: {type(result)}")
+
+            # 处理结果
+            response_data = self._process_agent_result(result, user_id, request_data)
+
+            # 更新任务状态为完成
+            chat_result_manager.update_task_status(task_id, "completed", response_data)
+
+            chat_logger.info(f"任务处理完成 - 任务ID={task_id}")
+
+        except Exception as e:
+            error_msg = f"聊天处理失败: {str(e)}"
+            chat_logger.error(f"{error_msg} - 任务ID={task_id}")
+
+            # 更新任务状态为失败
+            chat_result_manager.update_task_status(
+                task_id, "failed", error_message=error_msg
+            )
+
+    async def _run_agent_in_threadpool(
+        self, agent, message: str, thread_id: str, user_id: str
+    ):
+        """在线程池中执行Langchain Agent"""
+        loop = asyncio.get_event_loop()
+        thread_pool = self._get_thread_pool()
+
+        # 准备输入
+        inputs = {"messages": [HumanMessage(content=message)]}
+        config = {"configurable": {"thread_id": thread_id}}
+
+        chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
+
+        try:
+            # 在线程池中执行同步操作
+            result = await loop.run_in_executor(
+                thread_pool, lambda: agent.invoke(inputs, config)
+            )
+            return result
+        except Exception as e:
+            chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
+            raise
+
+    def _process_agent_result(
+        self, result: Dict[str, Any], user_id: str, request_data: Dict
+    ) -> Dict[str, Any]:
+        """处理Agent返回结果"""
+        all_messages = result["messages"]
+        processed_messages = []
+        all_ai_messages = []
+        all_tool_calls = []
+        final_answer = ""
+
+        for i, msg in enumerate(all_messages):
+            msg_data = {
+                "index": i,
+                "type": getattr(msg, "type", "unknown"),
+                "content": "",
+            }
+
+            # 获取内容
+            if hasattr(msg, "content"):
+                content = msg.content
+                if isinstance(content, str):
+                    msg_data["content"] = content
+                else:
+                    msg_data["content"] = str(content)
+
+            # 获取工具调用
+            if hasattr(msg, "tool_calls") and msg.tool_calls:
+                msg_data["tool_calls"] = msg.tool_calls
+                all_tool_calls.extend(msg.tool_calls)
+
+                for tool_call in msg.tool_calls:
+                    tool_name = tool_call.get("name", "unknown")
+                    tool_args = tool_call.get("args", {})
+                    chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
+
+            if hasattr(msg, "tool_call_id"):
+                msg_data["tool_call_id"] = msg.tool_call_id
+
+            if hasattr(msg, "name"):
+                msg_data["name"] = msg.name
+
+            processed_messages.append(msg_data)
+
+            # 收集AI消息
+            if msg_data["type"] == "ai":
+                all_ai_messages.append(msg_data)
+                final_answer = msg_data["content"]
+
+        # 构建响应
+        response = {
+            "final_answer": final_answer,
+            "all_ai_messages": all_ai_messages,
+            "all_messages": processed_messages,
+            "tool_calls": all_tool_calls,
+            "thread_id": request_data["thread_id"],
+            "user_identifier": user_id,
+            "backend_config": {
+                "backend_url": request_data["backend_url"] or "未配置",
+                "username": request_data["username"],
+                "has_token": bool(request_data["token"]),
+            },
+            "success": True,
+        }
+
+        # 记录日志
+        log_chat_entry(user_id, request_data["message"], response)
+
+        return response
+
+    async def get_task_result(self, task_id: str) -> Dict[str, Any]:
+        """获取任务结果"""
+        task_info = chat_result_manager.get_task(task_id)
+        chat_logger.info(f"获取任务结果 - 任务ID={task_id}, 状态={task_info['status']}")
+        if not task_info:
+            return {
+                "success": False,
+                "error": f"任务不存在: {task_id}",
+                "task_id": task_id,
+            }
+
+        return {
+            "task_id": task_id,
+            "status": task_info["status"],
+            "response": task_info["response_data"],
+            "error": task_info["error_message"],
+            "created_at": task_info["created_at"],
+            "updated_at": task_info["updated_at"],
+            "success": task_info["status"] == "completed",
+        }
+
+    async def shutdown(self):
+        """关闭服务"""
+        if self._thread_pool:
+            self._thread_pool.shutdown(wait=False)
+            self._thread_pool = None
+        chat_result_manager.close()
+        chat_logger.info("异步聊天服务已关闭")
+
+
+# 全局实例
+async_chat_service = AsyncChatService()

+ 405 - 0
core/chat_result_manager.py

@@ -0,0 +1,405 @@
+import os
+import sqlite3
+import uuid
+import json
+import datetime
+import re
+from typing import Dict, Any, Optional
+from pathlib import Path
+from utils.logger import chat_logger
+
+
+class ChatResultManager:
+    """聊天结果管理器 - 支持按用户+日期分割的SQLite数据库存储"""
+
+    def __init__(self):
+        self._db_connections = {}  # 存储不同用户+日期组合的连接
+        self._task_db_cache = {}  # 缓存任务ID对应的数据库文件
+        self._init_tables()
+
+    def _get_day_number(self, date=None):
+        """获取日期编号 (YYYYMMDD 格式)"""
+        if date is None:
+            date = datetime.datetime.now()
+        return date.strftime("%Y%m%d")
+
+    def _get_db_key(self, username, day_number=None):
+        """生成数据库连接缓存key"""
+        if day_number is None:
+            day_number = self._get_day_number()
+        return f"{username}_{day_number}"
+
+    def _get_db_path(self, username, day_number=None):
+        """获取数据库文件路径"""
+        if day_number is None:
+            day_number = self._get_day_number()
+
+        # 数据库文件存放目录
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        base_dir = os.path.join(project_root, "data", "chat_history")
+        os.makedirs(base_dir, exist_ok=True)
+
+        # 数据库文件名格式: chat_history_{username}_{day_number}.db
+        # 对用户名进行安全处理,避免特殊字符
+        safe_username = re.sub(r'[\\/:*?"<>|\0]', "_", username)
+        print(f"_get_db_path用户名: {safe_username}")
+        db_filename = f"chat_history_{safe_username}_{day_number}.db"
+        return os.path.join(base_dir, db_filename)
+
+    def _get_db_connection(self, username, day_number=None):
+        """获取数据库连接(按用户+日期分割)"""
+        if day_number is None:
+            day_number = self._get_day_number()
+
+        db_key = self._get_db_key(username, day_number)
+
+        # 如果连接不存在或需要重新连接,创建新连接
+        if db_key not in self._db_connections:
+            db_path = self._get_db_path(username, day_number)
+            conn = sqlite3.connect(db_path, check_same_thread=False)
+            conn.execute("PRAGMA journal_mode=WAL")
+            self._db_connections[db_key] = conn
+
+            # 初始化该数据库的表
+            self._init_tables_for_connection(conn)
+
+        return self._db_connections[db_key]
+
+    def _init_tables_for_connection(self, conn):
+        """为指定数据库连接初始化表结构"""
+        cursor = conn.cursor()
+
+        # 创建聊天任务表
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS chat_tasks (
+                task_id TEXT PRIMARY KEY,
+                status TEXT NOT NULL,  -- pending, processing, completed, failed
+                request_data TEXT NOT NULL,
+                response_data TEXT,
+                error_message TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                thread_id TEXT,
+                username TEXT
+            )
+        """
+        )
+
+        # 创建索引
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON chat_tasks(status)")
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_thread_id ON chat_tasks(thread_id)"
+        )
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_username ON chat_tasks(username)"
+        )
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_created_at ON chat_tasks(created_at)"
+        )
+
+        conn.commit()
+
+    def _init_tables(self):
+        """初始化所有数据库表(兼容性方法)"""
+        # 为默认用户初始化表,确保系统启动时至少有一个数据库可用
+        default_username = "default"
+        conn = self._get_db_connection(default_username)
+        self._init_tables_for_connection(conn)
+
+    def create_task(self, request_data: Dict[str, Any]) -> str:
+        """创建新的聊天任务"""
+        task_id = str(uuid.uuid4())
+        username = request_data.get("username", "default")
+        print(f"创建任务: {task_id} (用户: {username})")
+        conn = self._get_db_connection(username)
+        cursor = conn.cursor()
+
+        cursor.execute(
+            """
+            INSERT INTO chat_tasks 
+            (task_id, status, request_data, thread_id, username)
+            VALUES (?, ?, ?, ?, ?)
+        """,
+            (
+                task_id,
+                "pending",
+                json.dumps(request_data, ensure_ascii=False),
+                request_data.get("thread_id", "default"),
+                username,
+            ),
+        )
+
+        conn.commit()
+
+        # 缓存任务ID对应的数据库信息
+        self._task_db_cache[task_id] = {
+            "username": username,
+            "day_number": self._get_day_number(),
+            "db_file": os.path.basename(self._get_db_path(username)),
+        }
+
+        chat_logger.info(f"创建聊天任务: {task_id} (用户: {username})")
+        return task_id
+
+    def update_task_status(
+        self,
+        task_id: str,
+        status: str,
+        response_data: Dict[str, Any] = None,
+        error_message: str = None,
+    ):
+        """更新任务状态"""
+        # 首先检查缓存中是否有该任务对应的数据库信息
+        if task_id in self._task_db_cache:
+            cache_info = self._task_db_cache[task_id]
+            username = cache_info["username"]
+            day_number = cache_info["day_number"]
+
+            # 尝试在缓存的数据库中更新
+            db_path = self._get_db_path(username, day_number)
+            if os.path.exists(db_path):
+                try:
+                    self._update_task_in_db(
+                        username,
+                        task_id,
+                        status,
+                        response_data,
+                        error_message,
+                        day_number,
+                    )
+                    return
+                except Exception as e:
+                    chat_logger.warning(
+                        f"使用缓存更新任务状态失败: {task_id} - {str(e)}"
+                    )
+                    # 缓存可能已失效,继续使用备用策略
+                    del self._task_db_cache[task_id]
+
+        # 备用策略:如果缓存不存在或失效,使用原来的get_task方法查找
+        task_info = self.get_task(task_id)
+        if not task_info:
+            chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")
+            return
+
+        username = task_info.get("username", "default")
+        self._update_task_in_db(username, task_id, status, response_data, error_message)
+
+    def _update_task_in_db(
+        self,
+        username,
+        task_id,
+        status,
+        response_data=None,
+        error_message=None,
+        day_number=None,
+    ):
+        """在指定数据库中更新任务状态(内部方法)"""
+        conn = self._get_db_connection(username, day_number)
+        cursor = conn.cursor()
+
+        cursor.execute(
+            """
+            UPDATE chat_tasks 
+            SET status = ?, 
+                response_data = ?, 
+                error_message = ?,
+                updated_at = CURRENT_TIMESTAMP
+            WHERE task_id = ?
+        """,
+            (
+                status,
+                (
+                    json.dumps(response_data, ensure_ascii=False)
+                    if response_data
+                    else None
+                ),
+                error_message,
+                task_id,
+            ),
+        )
+
+        conn.commit()
+        chat_logger.info(f"更新任务状态: {task_id} -> {status} (用户: {username})")
+
+    def _search_task_in_db(self, db_path, task_id):
+        """在指定数据库中搜索任务"""
+        try:
+            conn = sqlite3.connect(db_path, check_same_thread=False)
+            cursor = conn.cursor()
+
+            cursor.execute(
+                """
+                SELECT task_id, status, request_data, response_data, 
+                       error_message, created_at, updated_at, thread_id, username
+                FROM chat_tasks 
+                WHERE task_id = ?
+            """,
+                (task_id,),
+            )
+
+            row = cursor.fetchone()
+            conn.close()
+
+            if row:
+                return {
+                    "task_id": row[0],
+                    "status": row[1],
+                    "request_data": json.loads(row[2]) if row[2] else {},
+                    "response_data": json.loads(row[3]) if row[3] else None,
+                    "error_message": row[4],
+                    "created_at": row[5],
+                    "updated_at": row[6],
+                    "thread_id": row[7],
+                    "username": row[8],
+                }
+        except Exception as e:
+            chat_logger.warning(
+                f"搜索任务时数据库访问异常: {os.path.basename(db_path)} - {str(e)}"
+            )
+
+        return None
+
+    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
+        """获取任务信息 - 智能查找策略"""
+        # 1. 首先检查缓存中是否有该任务对应的数据库信息
+        if task_id in self._task_db_cache:
+            cache_info = self._task_db_cache[task_id]
+            username = cache_info["username"]
+            day_number = cache_info["day_number"]
+
+            # 尝试在缓存的数据库中查找
+            db_path = self._get_db_path(username, day_number)
+            if os.path.exists(db_path):
+                result = self._search_task_in_db(db_path, task_id)
+                if result:
+                    return result
+            else:
+                # 如果缓存的文件不存在,清除缓存
+                del self._task_db_cache[task_id]
+
+        # 2. 如果没有缓存,尝试从请求上下文推断用户名
+        # 这里可以扩展为从请求头或其他上下文获取用户名
+        # 目前先使用默认策略
+
+        # 3. 按优先级搜索策略
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        base_dir = os.path.join(project_root, "data", "chat_history")
+
+        if not os.path.exists(base_dir):
+            return None
+
+        # 策略1: 搜索当天所有用户的数据库(最可能的情况)
+        today = self._get_day_number()
+        for db_file in os.listdir(base_dir):
+            if db_file.endswith(f"_{today}.db") and db_file.startswith("chat_history_"):
+                db_path = os.path.join(base_dir, db_file)
+                result = self._search_task_in_db(db_path, task_id)
+                if result:
+                    # 找到后缓存数据库信息
+                    self._task_db_cache[task_id] = {
+                        "username": result["username"],
+                        "day_number": today,
+                        "db_file": db_file,
+                    }
+                    return result
+
+        # 策略2: 搜索最近7天内的数据库(次优情况)
+        for days_back in range(1, 8):
+            search_date = datetime.datetime.now() - datetime.timedelta(days=days_back)
+            search_day = self._get_day_number(search_date)
+
+            for db_file in os.listdir(base_dir):
+                if db_file.endswith(f"_{search_day}.db") and db_file.startswith(
+                    "chat_history_"
+                ):
+                    db_path = os.path.join(base_dir, db_file)
+                    result = self._search_task_in_db(db_path, task_id)
+                    if result:
+                        # 找到后缓存数据库信息
+                        self._task_db_cache[task_id] = {
+                            "username": result["username"],
+                            "day_number": search_day,
+                            "db_file": db_file,
+                        }
+                        return result
+
+        # 策略3: 最后搜索所有数据库(最坏情况)
+        for db_file in os.listdir(base_dir):
+            if db_file.endswith(".db") and db_file.startswith("chat_history_"):
+                db_path = os.path.join(base_dir, db_file)
+                result = self._search_task_in_db(db_path, task_id)
+                if result:
+                    # 从文件名中提取用户名和日期
+                    parts = (
+                        db_file.replace("chat_history_", "")
+                        .replace(".db", "")
+                        .split("_")
+                    )
+                    if len(parts) >= 2:
+                        username_part = "_".join(parts[:-1])  # 用户名可能包含下划线
+                        day_number = parts[-1]
+
+                        self._task_db_cache[task_id] = {
+                            "username": result["username"],
+                            "day_number": day_number,
+                            "db_file": db_file,
+                        }
+                    return result
+
+        return None
+
+    def cleanup_old_tasks(self, max_days=7):
+        """清理超过指定天数的旧任务数据"""
+        cutoff_date = datetime.datetime.now() - datetime.timedelta(days=max_days)
+        cutoff_day = self._get_day_number(cutoff_date)
+
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        base_dir = os.path.join(project_root, "data", "chat_history")
+
+        if not os.path.exists(base_dir):
+            return
+
+        cleaned_count = 0
+        # 清理旧的数据库文件
+        for db_file in os.listdir(base_dir):
+            if db_file.endswith(".db") and db_file.startswith("chat_history_"):
+                # 从文件名中提取日期部分
+                # 格式: chat_history_{username}_{YYYYMMDD}.db
+                parts = db_file.split("_")
+                if len(parts) >= 3:
+                    try:
+                        file_day = parts[-1].replace(".db", "")
+                        if file_day < cutoff_day:
+                            db_path = os.path.join(base_dir, db_file)
+                            os.remove(db_path)
+                            cleaned_count += 1
+                            chat_logger.info(f"清理旧数据库文件: {db_file}")
+
+                            # 清理对应的缓存
+                            task_ids_to_remove = []
+                            for tid, cache_info in self._task_db_cache.items():
+                                if cache_info["db_file"] == db_file:
+                                    task_ids_to_remove.append(tid)
+
+                            for tid in task_ids_to_remove:
+                                del self._task_db_cache[tid]
+                    except (ValueError, IndexError):
+                        continue
+
+        chat_logger.info(f"清理完成: 删除了 {cleaned_count} 个旧数据库文件")
+
+    def close(self):
+        """关闭所有数据库连接"""
+        for db_key, conn in self._db_connections.items():
+            try:
+                conn.close()
+            except Exception as e:
+                chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")
+        self._db_connections.clear()
+        self._task_db_cache.clear()
+        chat_logger.info("所有数据库连接已关闭")
+
+
+# 全局实例
+chat_result_manager = ChatResultManager()

+ 8 - 4
core/lifespan_manager.py

@@ -3,6 +3,7 @@ from fastapi import FastAPI
 from core.agent_manager import agent_manager
 from utils.registration_manager import registration_manager
 from utils.logger import chat_logger
+from core.chat_result_manager import chat_result_manager
 
 
 @asynccontextmanager
@@ -12,16 +13,19 @@ async def lifespan(app: FastAPI):
         # 启动时检查注册状态,但不阻止启动
         registration_status = await registration_manager.check_registration()
         if not registration_status:
-            chat_logger.warning("⚠️ 服务启动:注册检查未通过,/chat接口将受限")
+            chat_logger.warning("服务启动:注册检查未通过,/chat接口将受限")
         else:
-            chat_logger.info("服务启动:注册检查通过")
+            chat_logger.info("服务启动:注册检查通过")
 
         await agent_manager.initialize()
-        chat_logger.info("🚀🚀 AI助手服务启动")
+        chat_logger.info("AI助手服务启动")
         yield
     finally:
         cleared_count = await agent_manager.shutdown()
-        chat_logger.info(f"🛑🛑🛑 AI助手服务停止,清理了 {cleared_count} 个Agent实例")
+        # 清理聊天结果管理器
+        chat_result_manager.close()
+        chat_result_manager.cleanup_old_tasks(max_days=7)
+        chat_logger.info(f"AI助手服务停止")
 
 
 def create_lifespan():

+ 18 - 18
tools/tool_factory.py

@@ -39,18 +39,18 @@ def get_all_tools() -> List[BaseTool]:
             module_name = file_path.stem.split(".")[0]
             if module_name not in tool_files:  # 避免重复添加
                 tool_files.append(module_name)
-                print(f"📦 发现编译后工具文件: {module_name}")
+                print(f"发现编译后工具文件: {module_name}")
 
     if not tool_files:
         tool_files = ["knowledge_tools", "sale_tools", "ware_tools"]
-        print("⚠️ 使用默认工具列表")
+        print("使用默认工具列表")
 
     for module_name in tool_files:
         try:
             # 导入模块
             full_module_path = f"tools.{module_name}"
             module = importlib.import_module(full_module_path)
-            # print(f"加载模块: {module_name}")
+            # print(f"加载模块: {module_name}")
 
             # 查找工具 - 使用更全面的方法
             tool_count = 0
@@ -63,14 +63,14 @@ def get_all_tools() -> List[BaseTool]:
                 attr = getattr(module, attr_name)
 
                 # # 详细调试信息
-                # print(f"  🔍 检查 {attr_name}:")
+                # print(f"  检查 {attr_name}:")
                 # print(f"    类型: {type(attr)}")
 
                 # 检查是否是BaseTool实例
                 if isinstance(attr, BaseTool):
                     tools.append(attr)
                     tool_count += 1
-                    # print(f"  发现BaseTool工具: {getattr(attr, 'name', attr_name)}")
+                    # print(f"  发现BaseTool工具: {getattr(attr, 'name', attr_name)}")
                     continue
 
                 # 检查是否是函数且具有工具属性
@@ -88,7 +88,7 @@ def get_all_tools() -> List[BaseTool]:
                     if is_tool_function(attr):
                         tools.append(attr)
                         tool_count += 1
-                        # print(f"  发现工具函数: {attr_name}")
+                        # print(f"  发现工具函数: {attr_name}")
 
             # 方法2: 检查模块的全局变量
             print(f"  🔍 检查模块全局变量...")
@@ -101,11 +101,11 @@ def get_all_tools() -> List[BaseTool]:
                         tools.append(value)
                         tool_count += 1
                         # print(
-                        #     f"  从全局变量发现BaseTool工具: {getattr(value, 'name', name)}"
+                        #     f"  从全局变量发现BaseTool工具: {getattr(value, 'name', name)}"
                         # )
 
             if tool_count == 0:
-                # print(f"  ⚠️ 模块 {module_name} 中未发现工具")
+                # print(f"  模块 {module_name} 中未发现工具")
                 # 尝试手动创建工具
                 manual_tools = create_tools_manually(module_name, module)
                 if manual_tools:
@@ -113,12 +113,12 @@ def get_all_tools() -> List[BaseTool]:
                     tool_count = len(manual_tools)
                     # print(f"  🔧 手动创建了 {tool_count} 个工具")
             else:
-                print(f"  📊 模块 {module_name} 中发现 {tool_count} 个工具")
+                print(f"  模块 {module_name} 中发现 {tool_count} 个工具")
 
         except Exception as e:
-            print(f"加载模块 {module_name} 失败: {e}")
+            print(f"加载模块 {module_name} 失败: {e}")
 
-    print(f"🎯 总共发现 {len(tools)} 个工具")
+    print(f"总共发现 {len(tools)} 个工具")
 
     # # 打印工具详情
     # for i, tool in enumerate(tools):
@@ -145,17 +145,17 @@ def create_tools_manually(module_name: str, module) -> List[BaseTool]:
                     # 使用@tool装饰器重新创建工具
                     tool_instance = tool(func)
                     tools.append(tool_instance)
-                    print(f"  🔧 手动创建工具: get_knowledge_list")
+                    print(f"  手动创建工具: get_knowledge_list")
 
             if hasattr(module, "get_knowledge_content"):
                 func = getattr(module, "get_knowledge_content")
                 if callable(func):
                     tool_instance = tool(func)
                     tools.append(tool_instance)
-                    print(f"  🔧 手动创建工具: get_knowledge_content")
+                    print(f"  手动创建工具: get_knowledge_content")
 
         except Exception as e:
-            print(f"  手动创建知识库工具失败: {e}")
+            print(f"  手动创建知识库工具失败: {e}")
 
     elif module_name == "sale_tools":
         # 手动创建销售工具
@@ -167,10 +167,10 @@ def create_tools_manually(module_name: str, module) -> List[BaseTool]:
                 if callable(func):
                     tool_instance = tool(func)
                     tools.append(tool_instance)
-                    print(f"  🔧 手动创建工具: get_sale_amt")
+                    print(f"  手动创建工具: get_sale_amt")
 
         except Exception as e:
-            print(f"  手动创建销售工具失败: {e}")
+            print(f"  手动创建销售工具失败: {e}")
 
     return tools
 
@@ -215,7 +215,7 @@ def is_tool_function(obj) -> bool:
 # 测试函数
 def test_tool_detection():
     """测试工具检测"""
-    print("🧪 测试工具检测...")
+    print("测试工具检测...")
 
     # 导入一个模块测试
     import tools.knowledge_tools as kt
@@ -225,7 +225,7 @@ def test_tool_detection():
     for attr_name in dir(kt):
         if not attr_name.startswith("_"):
             attr = getattr(kt, attr_name)
-            print(f"\n🔍 检查 {attr_name}:")
+            print(f"\n检查 {attr_name}:")
             print(f"  类型: {type(attr)}")
             print(f"  可调用: {callable(attr)}")
 

+ 6 - 0
utils/logger.py

@@ -31,6 +31,12 @@ def setup_logging():
         # 控制台处理器
         console_handler = logging.StreamHandler()
         console_handler.setLevel(logging.INFO)
+        # 确保控制台输出使用UTF-8编码
+        import sys
+        if hasattr(sys.stdout, 'reconfigure'):
+            sys.stdout.reconfigure(encoding='utf-8')
+        if hasattr(sys.stderr, 'reconfigure'):
+            sys.stderr.reconfigure(encoding='utf-8')
 
         # 格式化
         formatter = logging.Formatter(

+ 2 - 2
utils/registration_manager.py

@@ -32,9 +32,9 @@ class RegistrationManager:
             self._last_check_time = current_time
 
             if is_valid:
-                chat_logger.info(f"{message}")
+                chat_logger.info(f"{message}")
             else:
-                chat_logger.warning(f"{message}")
+                chat_logger.warning(f"{message}")
 
             return is_valid