import os
import sqlite3
import uuid
import json
import datetime
import re
from contextlib import closing
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
from utils.logger import chat_logger


class ChatResultManager:
    """Chat result manager backed by SQLite databases split per user and per day.

    Every (username, YYYYMMDD) pair is stored in its own database file,
    ``<project_root>/data/chat_history/chat_history_{safe_username}_{YYYYMMDD}.db``,
    which contains a single ``chat_tasks`` table.
    """

    # File name layout: chat_history_{safe_username}_{YYYYMMDD}.db
    _DB_FILE_PREFIX = "chat_history_"
    _DB_FILE_SUFFIX = ".db"
    _DAY_PATTERN = re.compile(r"\d{8}")
    # Sidecar files SQLite keeps next to a database in WAL journal mode.
    _WAL_SIDECAR_SUFFIXES = ("-wal", "-shm")

    def __init__(self):
        # Open connections keyed by "{username}_{day_number}" (see _get_db_key).
        self._db_connections = {}
        # task_id -> {"username", "day_number", "db_file"}: where a task was last seen.
        self._task_db_cache = {}
        self._init_tables()

    def _get_day_number(self, date: Optional[datetime.datetime] = None) -> str:
        """Return *date* (default: the current local time) formatted as ``YYYYMMDD``."""
        if date is None:
            date = datetime.datetime.now()
        return date.strftime("%Y%m%d")

    def _get_db_key(self, username: str, day_number: Optional[str] = None) -> str:
        """Return the connection-cache key for (username, day); day defaults to today."""
        if day_number is None:
            day_number = self._get_day_number()
        return f"{username}_{day_number}"

    def _get_base_dir(self, create: bool = False) -> str:
        """Return the directory that holds all chat history databases.

        The directory is ``data/chat_history`` under the parent of this file's
        directory. It is created first when *create* is true.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        base_dir = os.path.join(project_root, "data", "chat_history")
        if create:
            os.makedirs(base_dir, exist_ok=True)
        return base_dir

    @classmethod
    def _parse_day_from_filename(cls, db_file: str) -> Optional[str]:
        """Return the YYYYMMDD part of a ``chat_history_*_YYYYMMDD.db`` file name.

        Returns None when the name is not a chat history database or when its
        last ``_``-separated component is not an 8-digit day.
        """
        if not (
            db_file.startswith(cls._DB_FILE_PREFIX)
            and db_file.endswith(cls._DB_FILE_SUFFIX)
        ):
            return None
        stem = db_file[len(cls._DB_FILE_PREFIX):-len(cls._DB_FILE_SUFFIX)]
        # The username may itself contain underscores, so split from the right.
        _, sep, day = stem.rpartition("_")
        if not sep or not cls._DAY_PATTERN.fullmatch(day):
            return None
        return day

    def _get_db_path(self, username: str, day_number: Optional[str] = None) -> str:
        """Return the database file path for (username, day), creating the base dir.

        Characters that are unsafe in file names are replaced with ``_``.
        """
        if day_number is None:
            day_number = self._get_day_number()
        base_dir = self._get_base_dir(create=True)
        safe_username = re.sub(r'[\\/:*?"<>|\0]', "_", username)
        db_filename = (
            f"{self._DB_FILE_PREFIX}{safe_username}_{day_number}{self._DB_FILE_SUFFIX}"
        )
        return os.path.join(base_dir, db_filename)

    def _get_db_connection(self, username: str, day_number: Optional[str] = None):
        """Return the cached connection for (username, day), opening it on first use.

        A new connection switches the database to WAL journaling and ensures the
        ``chat_tasks`` schema exists. ``check_same_thread=False`` allows the
        connection to be used from threads other than the one that opened it.
        """
        if day_number is None:
            day_number = self._get_day_number()
        db_key = self._get_db_key(username, day_number)

        conn = self._db_connections.get(db_key)
        if conn is None:
            db_path = self._get_db_path(username, day_number)
            conn = sqlite3.connect(db_path, check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            self._db_connections[db_key] = conn
            self._init_tables_for_connection(conn)
        return conn

    def _init_tables_for_connection(self, conn):
        """Create the ``chat_tasks`` table and its indexes if they are missing."""
        cursor = conn.cursor()
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS chat_tasks (
                task_id TEXT PRIMARY KEY,
                status TEXT NOT NULL,  -- pending, processing, completed, failed
                request_data TEXT NOT NULL,
                response_data TEXT,
                error_message TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                thread_id TEXT,
                username TEXT
            )
            """
        )
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON chat_tasks(status)")
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_thread_id ON chat_tasks(thread_id)"
        )
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_username ON chat_tasks(username)"
        )
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_created_at ON chat_tasks(created_at)"
        )
        conn.commit()

    def _init_tables(self):
        """Ensure today's database for the "default" user exists with its schema.

        Kept for compatibility. Opening the connection already initialises the
        schema, so no second initialisation pass is needed.
        """
        self._get_db_connection("default")

    def create_task(self, request_data: Dict[str, Any]) -> str:
        """Insert a new ``pending`` task and return its generated task id.

        The task is stored in today's database for ``request_data["username"]``.
        A missing or empty username falls back to "default".
        """
        task_id = str(uuid.uuid4())
        username = request_data.get("username") or "default"
        # Compute the day once so the insert and the cache entry cannot disagree
        # when the call straddles midnight.
        day_number = self._get_day_number()

        conn = self._get_db_connection(username, day_number)
        conn.execute(
            """
            INSERT INTO chat_tasks (task_id, status, request_data, thread_id, username)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                task_id,
                "pending",
                json.dumps(request_data, ensure_ascii=False),
                request_data.get("thread_id", "default"),
                username,
            ),
        )
        conn.commit()

        self._task_db_cache[task_id] = {
            "username": username,
            "day_number": day_number,
            "db_file": os.path.basename(self._get_db_path(username, day_number)),
        }

        chat_logger.info(f"创建聊天任务: {task_id} (用户: {username})")
        return task_id

    def update_task_status(
        self,
        task_id: str,
        status: str,
        response_data: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ):
        """Set a task's status, response payload and error message.

        Tries the cached database location first. If that fails or updates no
        row, it searches for the task and updates it in the database where it
        was actually found. Failures are logged, not raised.
        """
        cache_info = self._task_db_cache.get(task_id)
        if cache_info is not None:
            username = cache_info["username"]
            day_number = cache_info["day_number"]
            if os.path.exists(self._get_db_path(username, day_number)):
                try:
                    if self._update_task_in_db(
                        username,
                        task_id,
                        status,
                        response_data,
                        error_message,
                        day_number,
                    ):
                        return
                    chat_logger.warning(f"缓存的数据库中未找到任务: {task_id}")
                except Exception as e:
                    chat_logger.warning(
                        f"使用缓存更新任务状态失败: {task_id} - {str(e)}"
                    )
            # The cached location is stale; drop it and search again.
            self._task_db_cache.pop(task_id, None)

        task_info, day_number = self._find_task(task_id)
        if not task_info:
            chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")
            return
        if day_number is None:
            # Found in a file whose name carries no day; without it the update
            # would be routed to today's database and silently touch nothing.
            chat_logger.error(f"更新任务状态失败: 无法确定任务所在的数据库 - {task_id}")
            return

        username = task_info.get("username", "default")
        if not self._update_task_in_db(
            username, task_id, status, response_data, error_message, day_number
        ):
            chat_logger.error(f"更新任务状态失败: 未更新任何记录 - {task_id}")

    def _update_task_in_db(
        self,
        username,
        task_id,
        status,
        response_data=None,
        error_message=None,
        day_number=None,
    ) -> bool:
        """Update one task row in the (username, day) database.

        Returns True when a row was updated, False when the task is not in that
        database. ``response_data`` is stored as JSON unless it is None.
        """
        conn = self._get_db_connection(username, day_number)
        cursor = conn.execute(
            """
            UPDATE chat_tasks
            SET status = ?, response_data = ?, error_message = ?, updated_at = CURRENT_TIMESTAMP
            WHERE task_id = ?
            """,
            (
                status,
                (
                    json.dumps(response_data, ensure_ascii=False)
                    if response_data is not None
                    else None
                ),
                error_message,
                task_id,
            ),
        )
        conn.commit()

        updated = cursor.rowcount > 0
        if updated:
            chat_logger.info(f"更新任务状态: {task_id} -> {status} (用户: {username})")
        return updated

    def _search_task_in_db(self, db_path, task_id):
        """Look up *task_id* in the database at *db_path* using a short-lived connection.

        Returns the task as a dict with the JSON columns decoded, or None when the
        task is absent or the file cannot be read. Read errors are logged.
        """
        try:
            # closing() releases the connection even if the query fails, e.g.
            # when the file has no chat_tasks table.
            with closing(sqlite3.connect(db_path, check_same_thread=False)) as conn:
                row = conn.execute(
                    """
                    SELECT task_id, status, request_data, response_data, error_message,
                           created_at, updated_at, thread_id, username
                    FROM chat_tasks WHERE task_id = ?
                    """,
                    (task_id,),
                ).fetchone()
            if not row:
                return None
            return {
                "task_id": row[0],
                "status": row[1],
                "request_data": json.loads(row[2]) if row[2] else {},
                "response_data": json.loads(row[3]) if row[3] else None,
                "error_message": row[4],
                "created_at": row[5],
                "updated_at": row[6],
                "thread_id": row[7],
                "username": row[8],
            }
        except (sqlite3.Error, ValueError) as e:
            chat_logger.warning(
                f"搜索任务时数据库访问异常: {os.path.basename(db_path)} - {str(e)}"
            )
            return None

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the task dict for *task_id*, or None if no database contains it."""
        task, _ = self._find_task(task_id)
        return task

    def _find_task(
        self, task_id: str
    ) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
        """Locate *task_id* and return ``(task, day_number)``.

        Lookup order:
        1. The cached location.
        2. Databases from today and the previous 7 days, newest first.
        3. Every remaining chat history database.

        Each file is searched at most once. ``day_number`` is None when the task
        is missing, or when it lives in a file whose name has no recognisable
        day. Hits with a known day are cached.
        """
        cache_info = self._task_db_cache.get(task_id)
        if cache_info is not None:
            day_number = cache_info["day_number"]
            db_path = self._get_db_path(cache_info["username"], day_number)
            if os.path.exists(db_path):
                result = self._search_task_in_db(db_path, task_id)
                if result:
                    return result, day_number
            else:
                # The cached file is gone (e.g. removed by cleanup); forget it.
                self._task_db_cache.pop(task_id, None)

        base_dir = self._get_base_dir()
        if not os.path.exists(base_dir):
            return None, None

        candidates = [
            (name, self._parse_day_from_filename(name))
            for name in os.listdir(base_dir)
            if name.startswith(self._DB_FILE_PREFIX)
            and name.endswith(self._DB_FILE_SUFFIX)
        ]

        now = datetime.datetime.now()
        recent_days = [
            self._get_day_number(now - datetime.timedelta(days=offset))
            for offset in range(8)
        ]
        recent_day_set = set(recent_days)
        # Recent days first (newest to oldest), then everything else in
        # directory-listing order.
        ordered = [c for day in recent_days for c in candidates if c[1] == day]
        ordered += [c for c in candidates if c[1] not in recent_day_set]

        for db_file, day in ordered:
            result = self._search_task_in_db(os.path.join(base_dir, db_file), task_id)
            if not result:
                continue
            if day is not None:
                self._task_db_cache[task_id] = {
                    "username": result["username"],
                    "day_number": day,
                    "db_file": db_file,
                }
            return result, day

        return None, None

    def cleanup_old_tasks(self, max_days=7):
        """Delete chat history databases older than *max_days* days.

        A database is removed when the YYYYMMDD in its file name sorts before
        "now minus *max_days*". Before deleting:
        - cached connections to those days are closed;
        - WAL sidecar files are removed together with each database;
        - cached task locations pointing at removed files are dropped.
        Files that cannot be removed are logged and skipped.
        """
        cutoff_date = datetime.datetime.now() - datetime.timedelta(days=max_days)
        cutoff_day = self._get_day_number(cutoff_date)

        base_dir = self._get_base_dir()
        if not os.path.exists(base_dir):
            return

        # Open handles would block deletion on Windows and, on POSIX, keep
        # writing to an unlinked file.
        self._close_connections_before(cutoff_day)

        removed_files = set()
        for db_file in os.listdir(base_dir):
            file_day = self._parse_day_from_filename(db_file)
            # Fixed-width YYYYMMDD strings compare correctly as text.
            if file_day is None or file_day >= cutoff_day:
                continue
            if self._remove_db_file(os.path.join(base_dir, db_file)):
                removed_files.add(db_file)
                chat_logger.info(f"清理旧数据库文件: {db_file}")

        stale_task_ids = [
            tid
            for tid, cache_info in self._task_db_cache.items()
            if cache_info["db_file"] in removed_files
        ]
        for tid in stale_task_ids:
            del self._task_db_cache[tid]

        chat_logger.info(f"清理完成: 删除了 {len(removed_files)} 个旧数据库文件")

    def _close_connections_before(self, cutoff_day: str):
        """Close and forget cached connections whose day sorts before *cutoff_day*."""
        # db_key is "{username}_{day_number}", so the day is the last "_" field.
        expired_keys = [
            db_key
            for db_key in self._db_connections
            if db_key.rsplit("_", 1)[-1] < cutoff_day
        ]
        for db_key in expired_keys:
            conn = self._db_connections.pop(db_key)
            try:
                conn.close()
            except Exception as e:
                chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")

    def _remove_db_file(self, db_path: str) -> bool:
        """Delete a database file and then its WAL sidecar files.

        Returns True when the main file was removed. Sidecars are only removed
        after the main file is gone, so a database that could not be deleted
        keeps its WAL.
        """
        try:
            os.remove(db_path)
        except OSError as e:
            chat_logger.warning(
                f"清理旧数据库文件失败: {os.path.basename(db_path)} - {str(e)}"
            )
            return False

        for suffix in self._WAL_SIDECAR_SUFFIXES:
            try:
                os.remove(db_path + suffix)
            except FileNotFoundError:
                pass
            except OSError as e:
                chat_logger.warning(
                    f"清理旧数据库文件失败: {os.path.basename(db_path)}{suffix} - {str(e)}"
                )
        return True

    def close(self):
        """Close every cached connection and clear the task-location cache."""
        for db_key, conn in self._db_connections.items():
            try:
                conn.close()
            except Exception as e:
                chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")
        self._db_connections.clear()
        self._task_db_cache.clear()
        chat_logger.info("所有数据库连接已关闭")


# Shared module-level instance. Constructing it opens (and creates, if needed)
# today's database for the "default" user.
chat_result_manager = ChatResultManager()