| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
- import os
- import sqlite3
- import uuid
- import json
- import datetime
- import re
- from typing import Dict, Any, Optional
- from pathlib import Path
- from utils.logger import chat_logger
class ChatResultManager:
    """Chat result manager backed by SQLite, split into one database per user per day.

    Each (username, day) pair maps to a file named
    ``chat_history_{safe_username}_{YYYYMMDD}.db`` inside the base directory.
    Write connections are cached per (username, day); task lookups open
    short-lived read connections and remember which database each task was
    found in.

    NOTE(review): connections are opened with ``check_same_thread=False`` but
    no lock is taken around them here; if the manager is shared between
    threads, callers presumably serialise access — confirm.
    """

    # Characters that are not allowed in file names on common platforms.
    _UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\0]')
    # The date component of a database file name / connection key: 8 ASCII digits.
    _DAY_NUMBER_PATTERN = re.compile(r"[0-9]{8}")

    def __init__(self, base_dir=None):
        """Create the manager and open today's database for the "default" user.

        Args:
            base_dir: Directory holding the database files. Defaults to
                ``<project_root>/data/chat_history``, where ``project_root`` is
                the parent of the directory containing this module.
        """
        if base_dir is None:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            base_dir = os.path.join(project_root, "data", "chat_history")
        self._base_dir = base_dir
        self._db_connections = {}  # "{username}_{YYYYMMDD}" -> sqlite3.Connection
        self._task_db_cache = {}  # task_id -> {"username", "day_number", "db_file"}
        self._init_tables()

    def _get_day_number(self, date=None):
        """Return ``date`` (default: now) formatted as a ``YYYYMMDD`` string."""
        if date is None:
            date = datetime.datetime.now()
        return date.strftime("%Y%m%d")

    def _get_db_key(self, username, day_number=None):
        """Return the connection-cache key ``"{username}_{day_number}"`` (default day: today)."""
        if day_number is None:
            day_number = self._get_day_number()
        return f"{username}_{day_number}"

    def _get_db_path(self, username, day_number=None):
        """Return the database file path for ``username`` on ``day_number`` (default: today).

        Creates the base directory if needed. Characters that are unsafe in file
        names are replaced with ``_``, so e.g. ``"a/b"`` and ``"a_b"`` share a file.
        """
        if day_number is None:
            day_number = self._get_day_number()
        os.makedirs(self._base_dir, exist_ok=True)
        safe_username = self._UNSAFE_FILENAME_CHARS.sub("_", username)
        return os.path.join(
            self._base_dir, f"chat_history_{safe_username}_{day_number}.db"
        )

    @classmethod
    def _day_from_db_filename(cls, db_file):
        """Return the trailing ``YYYYMMDD`` of a ``..._YYYYMMDD.db`` file name, else ``None``."""
        stem = db_file[: -len(".db")] if db_file.endswith(".db") else db_file
        _, sep, day = stem.rpartition("_")
        if sep and cls._DAY_NUMBER_PATTERN.fullmatch(day):
            return day
        return None

    def _get_db_connection(self, username, day_number=None):
        """Return the cached connection for (username, day), opening it on first use.

        A new connection is switched to WAL journaling and gets the
        ``chat_tasks`` schema created if it is missing. It is only cached once
        that initialisation succeeded.
        """
        if day_number is None:
            day_number = self._get_day_number()
        db_key = self._get_db_key(username, day_number)
        conn = self._db_connections.get(db_key)
        if conn is None:
            conn = sqlite3.connect(
                self._get_db_path(username, day_number), check_same_thread=False
            )
            try:
                conn.execute("PRAGMA journal_mode=WAL")
                self._init_tables_for_connection(conn)
            except Exception:
                # Do not leak (or cache) a half-initialised connection.
                conn.close()
                raise
            self._db_connections[db_key] = conn
        return conn

    def _close_connection(self, db_key):
        """Close and forget the cached connection for ``db_key``, if there is one."""
        conn = self._db_connections.pop(db_key, None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")

    def _init_tables_for_connection(self, conn):
        """Create the ``chat_tasks`` table and its indexes on ``conn`` if missing."""
        cursor = conn.cursor()
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS chat_tasks (
            task_id TEXT PRIMARY KEY,
            status TEXT NOT NULL, -- pending, processing, completed, failed
            request_data TEXT NOT NULL,
            response_data TEXT,
            error_message TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            thread_id TEXT,
            username TEXT
            )
            """
        )
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON chat_tasks(status)")
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_thread_id ON chat_tasks(thread_id)"
        )
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_username ON chat_tasks(username)"
        )
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_created_at ON chat_tasks(created_at)"
        )
        conn.commit()

    def _init_tables(self):
        """Ensure today's "default" user database exists with its schema (compatibility method)."""
        conn = self._get_db_connection("default")
        self._init_tables_for_connection(conn)

    def create_task(self, request_data: Dict[str, Any]) -> str:
        """Insert a new ``pending`` task into today's database for its user.

        ``request_data["username"]`` selects the database (default ``"default"``),
        ``request_data["thread_id"]`` is stored in its own column (default
        ``"default"``), and the whole dict is stored as JSON.

        Returns:
            The new task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        username = request_data.get("username", "default")
        # Resolve the day once so the row, the connection and the cache entry all
        # refer to the same file even if midnight passes during this call.
        day_number = self._get_day_number()
        conn = self._get_db_connection(username, day_number)
        conn.execute(
            """
            INSERT INTO chat_tasks
            (task_id, status, request_data, thread_id, username)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                task_id,
                "pending",
                json.dumps(request_data, ensure_ascii=False),
                request_data.get("thread_id", "default"),
                username,
            ),
        )
        conn.commit()
        self._task_db_cache[task_id] = {
            "username": username,
            "day_number": day_number,
            "db_file": os.path.basename(self._get_db_path(username, day_number)),
        }
        chat_logger.info(f"创建聊天任务: {task_id} (用户: {username})")
        return task_id

    def update_task_status(
        self,
        task_id: str,
        status: str,
        response_data: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ):
        """Set a task's status, response payload and error message.

        The database remembered in the task cache is tried first. If that raises
        or no row matches, the cache entry is dropped, the task is located with
        ``get_task`` and updated in the database it was found in. A task that
        cannot be found is logged as an error, not raised.
        """
        cache_info = self._task_db_cache.get(task_id)
        if cache_info is not None:
            username = cache_info["username"]
            day_number = cache_info["day_number"]
            if os.path.exists(self._get_db_path(username, day_number)):
                try:
                    if self._update_task_in_db(
                        username,
                        task_id,
                        status,
                        response_data,
                        error_message,
                        day_number,
                    ):
                        return
                except Exception as e:
                    chat_logger.warning(
                        f"使用缓存更新任务状态失败: {task_id} - {str(e)}"
                    )
                # The cached location is stale; forget it and search again.
                self._task_db_cache.pop(task_id, None)

        task_info = self.get_task(task_id)
        if not task_info:
            chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")
            return
        username = task_info.get("username", "default")
        # get_task() cached the day of the database it found the task in; update
        # that database rather than today's so older tasks are not silently missed.
        day_number = self._task_db_cache.get(task_id, {}).get("day_number")
        if not self._update_task_in_db(
            username, task_id, status, response_data, error_message, day_number
        ):
            chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")

    def _update_task_in_db(
        self,
        username,
        task_id,
        status,
        response_data=None,
        error_message=None,
        day_number=None,
    ):
        """Update one task row in the (username, day) database (internal).

        ``response_data`` is stored as JSON whenever it is not ``None``, so an
        empty dict round-trips as ``{}`` instead of becoming NULL.

        Returns:
            ``True`` if a row was updated, ``False`` if that database holds no
            row with ``task_id``.
        """
        conn = self._get_db_connection(username, day_number)
        cursor = conn.execute(
            """
            UPDATE chat_tasks
            SET status = ?,
            response_data = ?,
            error_message = ?,
            updated_at = CURRENT_TIMESTAMP
            WHERE task_id = ?
            """,
            (
                status,
                (
                    json.dumps(response_data, ensure_ascii=False)
                    if response_data is not None
                    else None
                ),
                error_message,
                task_id,
            ),
        )
        conn.commit()
        if cursor.rowcount == 0:
            return False
        chat_logger.info(f"更新任务状态: {task_id} -> {status} (用户: {username})")
        return True

    def _search_task_in_db(self, db_path, task_id):
        """Look up ``task_id`` in the database at ``db_path`` with a short-lived connection.

        Returns the task as a dict (JSON columns decoded), or ``None`` when the
        row is missing or the database cannot be read. Read/decode errors are
        logged as warnings, not raised.
        """
        conn = None
        try:
            conn = sqlite3.connect(db_path, check_same_thread=False)
            row = conn.execute(
                """
                SELECT task_id, status, request_data, response_data,
                error_message, created_at, updated_at, thread_id, username
                FROM chat_tasks
                WHERE task_id = ?
                """,
                (task_id,),
            ).fetchone()
            if row:
                return {
                    "task_id": row[0],
                    "status": row[1],
                    "request_data": json.loads(row[2]) if row[2] else {},
                    "response_data": json.loads(row[3]) if row[3] else None,
                    "error_message": row[4],
                    "created_at": row[5],
                    "updated_at": row[6],
                    "thread_id": row[7],
                    "username": row[8],
                }
        except Exception as e:
            chat_logger.warning(
                f"搜索任务时数据库访问异常: {os.path.basename(db_path)} - {str(e)}"
            )
        finally:
            # Release the file handle even when the query or decoding fails.
            if conn is not None:
                conn.close()
        return None

    def _search_and_cache(self, db_file, task_id, day_number):
        """Search one file in the base directory; cache the task's location if found.

        The location is only cached when ``day_number`` is known.
        """
        result = self._search_task_in_db(os.path.join(self._base_dir, db_file), task_id)
        if result and day_number is not None:
            self._task_db_cache[task_id] = {
                "username": result["username"],
                "day_number": day_number,
                "db_file": db_file,
            }
        return result

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return task ``task_id`` as a dict, or ``None`` if no database contains it.

        Lookup order:
          1. the database remembered in the task cache (the entry is dropped if
             its file no longer exists);
          2. every database dated today, then each of the previous 7 days;
          3. every remaining ``chat_history_*.db`` file.
        Each file is searched at most once per call, and the location of a
        found task is cached for later lookups and updates.
        """
        cache_info = self._task_db_cache.get(task_id)
        if cache_info is not None:
            db_path = self._get_db_path(cache_info["username"], cache_info["day_number"])
            if os.path.exists(db_path):
                result = self._search_task_in_db(db_path, task_id)
                if result:
                    return result
            else:
                del self._task_db_cache[task_id]

        if not os.path.exists(self._base_dir):
            return None
        db_files = [
            name
            for name in os.listdir(self._base_dir)
            if name.startswith("chat_history_") and name.endswith(".db")
        ]
        searched = set()

        # Strategies 1 and 2: today's databases first, then the previous 7 days'.
        now = datetime.datetime.now()
        for days_back in range(8):
            day = self._get_day_number(now - datetime.timedelta(days=days_back))
            suffix = f"_{day}.db"
            for db_file in db_files:
                if db_file.endswith(suffix):
                    searched.add(db_file)
                    result = self._search_and_cache(db_file, task_id, day)
                    if result:
                        return result

        # Strategy 3: every file not searched above (older or oddly named ones).
        for db_file in db_files:
            if db_file in searched:
                continue
            result = self._search_and_cache(
                db_file, task_id, self._day_from_db_filename(db_file)
            )
            if result:
                return result
        return None

    def cleanup_old_tasks(self, max_days=7):
        """Delete database files dated more than ``max_days`` days before today.

        Only ``chat_history_*_YYYYMMDD.db`` files with a valid eight-digit date
        are considered. Cached connections to those days are closed first, the
        WAL side files (``-wal``/``-shm``) are removed with each database, and
        cached task locations pointing at deleted files are dropped.
        """
        cutoff_date = datetime.datetime.now() - datetime.timedelta(days=max_days)
        cutoff_day = self._get_day_number(cutoff_date)
        if not os.path.exists(self._base_dir):
            return

        # Close connections to days that are about to be deleted; otherwise
        # later writes land in an unlinked file (POSIX) or the delete fails
        # because the file is still open (Windows).
        for db_key in list(self._db_connections):
            key_day = db_key.rpartition("_")[2]
            if self._DAY_NUMBER_PATTERN.fullmatch(key_day) and key_day < cutoff_day:
                self._close_connection(db_key)

        cleaned_count = 0
        for db_file in os.listdir(self._base_dir):
            if not (db_file.startswith("chat_history_") and db_file.endswith(".db")):
                continue
            file_day = self._day_from_db_filename(db_file)
            # Skip names without a real date suffix rather than comparing arbitrary text.
            if file_day is None or file_day >= cutoff_day:
                continue
            db_path = os.path.join(self._base_dir, db_file)
            try:
                os.remove(db_path)
            except OSError as e:
                chat_logger.warning(f"清理旧数据库文件失败: {db_file} - {str(e)}")
                continue
            for sidecar in (db_path + "-wal", db_path + "-shm"):
                if os.path.exists(sidecar):
                    try:
                        os.remove(sidecar)
                    except OSError as e:
                        chat_logger.warning(
                            f"清理旧数据库文件失败: {os.path.basename(sidecar)} - {str(e)}"
                        )
            cleaned_count += 1
            chat_logger.info(f"清理旧数据库文件: {db_file}")
            stale_task_ids = [
                tid
                for tid, info in self._task_db_cache.items()
                if info["db_file"] == db_file
            ]
            for tid in stale_task_ids:
                del self._task_db_cache[tid]
        chat_logger.info(f"清理完成: 删除了 {cleaned_count} 个旧数据库文件")

    def close(self):
        """Close every cached connection and clear the task-location cache."""
        for db_key in list(self._db_connections):
            self._close_connection(db_key)
        self._task_db_cache.clear()
        chat_logger.info("所有数据库连接已关闭")
# Global shared instance. Constructing it at import time opens (creating if
# needed) today's database file for the "default" user.
chat_result_manager = ChatResultManager()
|