|
|
@@ -0,0 +1,405 @@
|
|
|
+import os
|
|
|
+import sqlite3
|
|
|
+import uuid
|
|
|
+import json
|
|
|
+import datetime
|
|
|
+import re
|
|
|
+from typing import Dict, Any, Optional
|
|
|
+from pathlib import Path
|
|
|
+from utils.logger import chat_logger
|
|
|
+
|
|
|
+
|
|
|
+class ChatResultManager:
|
|
|
+ """聊天结果管理器 - 支持按用户+日期分割的SQLite数据库存储"""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self._db_connections = {} # 存储不同用户+日期组合的连接
|
|
|
+ self._task_db_cache = {} # 缓存任务ID对应的数据库文件
|
|
|
+ self._init_tables()
|
|
|
+
|
|
|
+ def _get_day_number(self, date=None):
|
|
|
+ """获取日期编号 (YYYYMMDD 格式)"""
|
|
|
+ if date is None:
|
|
|
+ date = datetime.datetime.now()
|
|
|
+ return date.strftime("%Y%m%d")
|
|
|
+
|
|
|
+ def _get_db_key(self, username, day_number=None):
|
|
|
+ """生成数据库连接缓存key"""
|
|
|
+ if day_number is None:
|
|
|
+ day_number = self._get_day_number()
|
|
|
+ return f"{username}_{day_number}"
|
|
|
+
|
|
|
+ def _get_db_path(self, username, day_number=None):
|
|
|
+ """获取数据库文件路径"""
|
|
|
+ if day_number is None:
|
|
|
+ day_number = self._get_day_number()
|
|
|
+
|
|
|
+ # 数据库文件存放目录
|
|
|
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
+ base_dir = os.path.join(project_root, "data", "chat_history")
|
|
|
+ os.makedirs(base_dir, exist_ok=True)
|
|
|
+
|
|
|
+ # 数据库文件名格式: chat_history_{username}_{day_number}.db
|
|
|
+ # 对用户名进行安全处理,避免特殊字符
|
|
|
+ safe_username = re.sub(r'[\\/:*?"<>|\0]', "_", username)
|
|
|
+ print(f"_get_db_path用户名: {safe_username}")
|
|
|
+ db_filename = f"chat_history_{safe_username}_{day_number}.db"
|
|
|
+ return os.path.join(base_dir, db_filename)
|
|
|
+
|
|
|
+ def _get_db_connection(self, username, day_number=None):
|
|
|
+ """获取数据库连接(按用户+日期分割)"""
|
|
|
+ if day_number is None:
|
|
|
+ day_number = self._get_day_number()
|
|
|
+
|
|
|
+ db_key = self._get_db_key(username, day_number)
|
|
|
+
|
|
|
+ # 如果连接不存在或需要重新连接,创建新连接
|
|
|
+ if db_key not in self._db_connections:
|
|
|
+ db_path = self._get_db_path(username, day_number)
|
|
|
+ conn = sqlite3.connect(db_path, check_same_thread=False)
|
|
|
+ conn.execute("PRAGMA journal_mode=WAL")
|
|
|
+ self._db_connections[db_key] = conn
|
|
|
+
|
|
|
+ # 初始化该数据库的表
|
|
|
+ self._init_tables_for_connection(conn)
|
|
|
+
|
|
|
+ return self._db_connections[db_key]
|
|
|
+
|
|
|
+ def _init_tables_for_connection(self, conn):
|
|
|
+ """为指定数据库连接初始化表结构"""
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # 创建聊天任务表
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ CREATE TABLE IF NOT EXISTS chat_tasks (
|
|
|
+ task_id TEXT PRIMARY KEY,
|
|
|
+ status TEXT NOT NULL, -- pending, processing, completed, failed
|
|
|
+ request_data TEXT NOT NULL,
|
|
|
+ response_data TEXT,
|
|
|
+ error_message TEXT,
|
|
|
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
+ thread_id TEXT,
|
|
|
+ username TEXT
|
|
|
+ )
|
|
|
+ """
|
|
|
+ )
|
|
|
+
|
|
|
+ # 创建索引
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON chat_tasks(status)")
|
|
|
+ cursor.execute(
|
|
|
+ "CREATE INDEX IF NOT EXISTS idx_thread_id ON chat_tasks(thread_id)"
|
|
|
+ )
|
|
|
+ cursor.execute(
|
|
|
+ "CREATE INDEX IF NOT EXISTS idx_username ON chat_tasks(username)"
|
|
|
+ )
|
|
|
+ cursor.execute(
|
|
|
+ "CREATE INDEX IF NOT EXISTS idx_created_at ON chat_tasks(created_at)"
|
|
|
+ )
|
|
|
+
|
|
|
+ conn.commit()
|
|
|
+
|
|
|
+ def _init_tables(self):
|
|
|
+ """初始化所有数据库表(兼容性方法)"""
|
|
|
+ # 为默认用户初始化表,确保系统启动时至少有一个数据库可用
|
|
|
+ default_username = "default"
|
|
|
+ conn = self._get_db_connection(default_username)
|
|
|
+ self._init_tables_for_connection(conn)
|
|
|
+
|
|
|
+ def create_task(self, request_data: Dict[str, Any]) -> str:
|
|
|
+ """创建新的聊天任务"""
|
|
|
+ task_id = str(uuid.uuid4())
|
|
|
+ username = request_data.get("username", "default")
|
|
|
+ print(f"创建任务: {task_id} (用户: {username})")
|
|
|
+ conn = self._get_db_connection(username)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO chat_tasks
|
|
|
+ (task_id, status, request_data, thread_id, username)
|
|
|
+ VALUES (?, ?, ?, ?, ?)
|
|
|
+ """,
|
|
|
+ (
|
|
|
+ task_id,
|
|
|
+ "pending",
|
|
|
+ json.dumps(request_data, ensure_ascii=False),
|
|
|
+ request_data.get("thread_id", "default"),
|
|
|
+ username,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ conn.commit()
|
|
|
+
|
|
|
+ # 缓存任务ID对应的数据库信息
|
|
|
+ self._task_db_cache[task_id] = {
|
|
|
+ "username": username,
|
|
|
+ "day_number": self._get_day_number(),
|
|
|
+ "db_file": os.path.basename(self._get_db_path(username)),
|
|
|
+ }
|
|
|
+
|
|
|
+ chat_logger.info(f"创建聊天任务: {task_id} (用户: {username})")
|
|
|
+ return task_id
|
|
|
+
|
|
|
+ def update_task_status(
|
|
|
+ self,
|
|
|
+ task_id: str,
|
|
|
+ status: str,
|
|
|
+ response_data: Dict[str, Any] = None,
|
|
|
+ error_message: str = None,
|
|
|
+ ):
|
|
|
+ """更新任务状态"""
|
|
|
+ # 首先检查缓存中是否有该任务对应的数据库信息
|
|
|
+ if task_id in self._task_db_cache:
|
|
|
+ cache_info = self._task_db_cache[task_id]
|
|
|
+ username = cache_info["username"]
|
|
|
+ day_number = cache_info["day_number"]
|
|
|
+
|
|
|
+ # 尝试在缓存的数据库中更新
|
|
|
+ db_path = self._get_db_path(username, day_number)
|
|
|
+ if os.path.exists(db_path):
|
|
|
+ try:
|
|
|
+ self._update_task_in_db(
|
|
|
+ username,
|
|
|
+ task_id,
|
|
|
+ status,
|
|
|
+ response_data,
|
|
|
+ error_message,
|
|
|
+ day_number,
|
|
|
+ )
|
|
|
+ return
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.warning(
|
|
|
+ f"使用缓存更新任务状态失败: {task_id} - {str(e)}"
|
|
|
+ )
|
|
|
+ # 缓存可能已失效,继续使用备用策略
|
|
|
+ del self._task_db_cache[task_id]
|
|
|
+
|
|
|
+ # 备用策略:如果缓存不存在或失效,使用原来的get_task方法查找
|
|
|
+ task_info = self.get_task(task_id)
|
|
|
+ if not task_info:
|
|
|
+ chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")
|
|
|
+ return
|
|
|
+
|
|
|
+ username = task_info.get("username", "default")
|
|
|
+ self._update_task_in_db(username, task_id, status, response_data, error_message)
|
|
|
+
|
|
|
+ def _update_task_in_db(
|
|
|
+ self,
|
|
|
+ username,
|
|
|
+ task_id,
|
|
|
+ status,
|
|
|
+ response_data=None,
|
|
|
+ error_message=None,
|
|
|
+ day_number=None,
|
|
|
+ ):
|
|
|
+ """在指定数据库中更新任务状态(内部方法)"""
|
|
|
+ conn = self._get_db_connection(username, day_number)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ UPDATE chat_tasks
|
|
|
+ SET status = ?,
|
|
|
+ response_data = ?,
|
|
|
+ error_message = ?,
|
|
|
+ updated_at = CURRENT_TIMESTAMP
|
|
|
+ WHERE task_id = ?
|
|
|
+ """,
|
|
|
+ (
|
|
|
+ status,
|
|
|
+ (
|
|
|
+ json.dumps(response_data, ensure_ascii=False)
|
|
|
+ if response_data
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ error_message,
|
|
|
+ task_id,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ conn.commit()
|
|
|
+ chat_logger.info(f"更新任务状态: {task_id} -> {status} (用户: {username})")
|
|
|
+
|
|
|
+ def _search_task_in_db(self, db_path, task_id):
|
|
|
+ """在指定数据库中搜索任务"""
|
|
|
+ try:
|
|
|
+ conn = sqlite3.connect(db_path, check_same_thread=False)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT task_id, status, request_data, response_data,
|
|
|
+ error_message, created_at, updated_at, thread_id, username
|
|
|
+ FROM chat_tasks
|
|
|
+ WHERE task_id = ?
|
|
|
+ """,
|
|
|
+ (task_id,),
|
|
|
+ )
|
|
|
+
|
|
|
+ row = cursor.fetchone()
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ if row:
|
|
|
+ return {
|
|
|
+ "task_id": row[0],
|
|
|
+ "status": row[1],
|
|
|
+ "request_data": json.loads(row[2]) if row[2] else {},
|
|
|
+ "response_data": json.loads(row[3]) if row[3] else None,
|
|
|
+ "error_message": row[4],
|
|
|
+ "created_at": row[5],
|
|
|
+ "updated_at": row[6],
|
|
|
+ "thread_id": row[7],
|
|
|
+ "username": row[8],
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.warning(
|
|
|
+ f"搜索任务时数据库访问异常: {os.path.basename(db_path)} - {str(e)}"
|
|
|
+ )
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
|
|
+ """获取任务信息 - 智能查找策略"""
|
|
|
+ # 1. 首先检查缓存中是否有该任务对应的数据库信息
|
|
|
+ if task_id in self._task_db_cache:
|
|
|
+ cache_info = self._task_db_cache[task_id]
|
|
|
+ username = cache_info["username"]
|
|
|
+ day_number = cache_info["day_number"]
|
|
|
+
|
|
|
+ # 尝试在缓存的数据库中查找
|
|
|
+ db_path = self._get_db_path(username, day_number)
|
|
|
+ if os.path.exists(db_path):
|
|
|
+ result = self._search_task_in_db(db_path, task_id)
|
|
|
+ if result:
|
|
|
+ return result
|
|
|
+ else:
|
|
|
+ # 如果缓存的文件不存在,清除缓存
|
|
|
+ del self._task_db_cache[task_id]
|
|
|
+
|
|
|
+ # 2. 如果没有缓存,尝试从请求上下文推断用户名
|
|
|
+ # 这里可以扩展为从请求头或其他上下文获取用户名
|
|
|
+ # 目前先使用默认策略
|
|
|
+
|
|
|
+ # 3. 按优先级搜索策略
|
|
|
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
+ base_dir = os.path.join(project_root, "data", "chat_history")
|
|
|
+
|
|
|
+ if not os.path.exists(base_dir):
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 策略1: 搜索当天所有用户的数据库(最可能的情况)
|
|
|
+ today = self._get_day_number()
|
|
|
+ for db_file in os.listdir(base_dir):
|
|
|
+ if db_file.endswith(f"_{today}.db") and db_file.startswith("chat_history_"):
|
|
|
+ db_path = os.path.join(base_dir, db_file)
|
|
|
+ result = self._search_task_in_db(db_path, task_id)
|
|
|
+ if result:
|
|
|
+ # 找到后缓存数据库信息
|
|
|
+ self._task_db_cache[task_id] = {
|
|
|
+ "username": result["username"],
|
|
|
+ "day_number": today,
|
|
|
+ "db_file": db_file,
|
|
|
+ }
|
|
|
+ return result
|
|
|
+
|
|
|
+ # 策略2: 搜索最近7天内的数据库(次优情况)
|
|
|
+ for days_back in range(1, 8):
|
|
|
+ search_date = datetime.datetime.now() - datetime.timedelta(days=days_back)
|
|
|
+ search_day = self._get_day_number(search_date)
|
|
|
+
|
|
|
+ for db_file in os.listdir(base_dir):
|
|
|
+ if db_file.endswith(f"_{search_day}.db") and db_file.startswith(
|
|
|
+ "chat_history_"
|
|
|
+ ):
|
|
|
+ db_path = os.path.join(base_dir, db_file)
|
|
|
+ result = self._search_task_in_db(db_path, task_id)
|
|
|
+ if result:
|
|
|
+ # 找到后缓存数据库信息
|
|
|
+ self._task_db_cache[task_id] = {
|
|
|
+ "username": result["username"],
|
|
|
+ "day_number": search_day,
|
|
|
+ "db_file": db_file,
|
|
|
+ }
|
|
|
+ return result
|
|
|
+
|
|
|
+ # 策略3: 最后搜索所有数据库(最坏情况)
|
|
|
+ for db_file in os.listdir(base_dir):
|
|
|
+ if db_file.endswith(".db") and db_file.startswith("chat_history_"):
|
|
|
+ db_path = os.path.join(base_dir, db_file)
|
|
|
+ result = self._search_task_in_db(db_path, task_id)
|
|
|
+ if result:
|
|
|
+ # 从文件名中提取用户名和日期
|
|
|
+ parts = (
|
|
|
+ db_file.replace("chat_history_", "")
|
|
|
+ .replace(".db", "")
|
|
|
+ .split("_")
|
|
|
+ )
|
|
|
+ if len(parts) >= 2:
|
|
|
+ username_part = "_".join(parts[:-1]) # 用户名可能包含下划线
|
|
|
+ day_number = parts[-1]
|
|
|
+
|
|
|
+ self._task_db_cache[task_id] = {
|
|
|
+ "username": result["username"],
|
|
|
+ "day_number": day_number,
|
|
|
+ "db_file": db_file,
|
|
|
+ }
|
|
|
+ return result
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ def cleanup_old_tasks(self, max_days=7):
|
|
|
+ """清理超过指定天数的旧任务数据"""
|
|
|
+ cutoff_date = datetime.datetime.now() - datetime.timedelta(days=max_days)
|
|
|
+ cutoff_day = self._get_day_number(cutoff_date)
|
|
|
+
|
|
|
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
+ base_dir = os.path.join(project_root, "data", "chat_history")
|
|
|
+
|
|
|
+ if not os.path.exists(base_dir):
|
|
|
+ return
|
|
|
+
|
|
|
+ cleaned_count = 0
|
|
|
+ # 清理旧的数据库文件
|
|
|
+ for db_file in os.listdir(base_dir):
|
|
|
+ if db_file.endswith(".db") and db_file.startswith("chat_history_"):
|
|
|
+ # 从文件名中提取日期部分
|
|
|
+ # 格式: chat_history_{username}_{YYYYMMDD}.db
|
|
|
+ parts = db_file.split("_")
|
|
|
+ if len(parts) >= 3:
|
|
|
+ try:
|
|
|
+ file_day = parts[-1].replace(".db", "")
|
|
|
+ if file_day < cutoff_day:
|
|
|
+ db_path = os.path.join(base_dir, db_file)
|
|
|
+ os.remove(db_path)
|
|
|
+ cleaned_count += 1
|
|
|
+ chat_logger.info(f"清理旧数据库文件: {db_file}")
|
|
|
+
|
|
|
+ # 清理对应的缓存
|
|
|
+ task_ids_to_remove = []
|
|
|
+ for tid, cache_info in self._task_db_cache.items():
|
|
|
+ if cache_info["db_file"] == db_file:
|
|
|
+ task_ids_to_remove.append(tid)
|
|
|
+
|
|
|
+ for tid in task_ids_to_remove:
|
|
|
+ del self._task_db_cache[tid]
|
|
|
+ except (ValueError, IndexError):
|
|
|
+ continue
|
|
|
+
|
|
|
+ chat_logger.info(f"清理完成: 删除了 {cleaned_count} 个旧数据库文件")
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ """关闭所有数据库连接"""
|
|
|
+ for db_key, conn in self._db_connections.items():
|
|
|
+ try:
|
|
|
+ conn.close()
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")
|
|
|
+ self._db_connections.clear()
|
|
|
+ self._task_db_cache.clear()
|
|
|
+ chat_logger.info("所有数据库连接已关闭")
|
|
|
+
|
|
|
+
|
|
|
+# 全局实例
|
|
|
+chat_result_manager = ChatResultManager()
|