# chat_result_manager.py
import contextlib
import datetime
import json
import os
import re
import sqlite3
import uuid
from pathlib import Path
from typing import Dict, Any, Optional

from utils.logger import chat_logger
  10. class ChatResultManager:
  11. """聊天结果管理器 - 支持按用户+日期分割的SQLite数据库存储"""
  12. def __init__(self):
  13. self._db_connections = {} # 存储不同用户+日期组合的连接
  14. self._task_db_cache = {} # 缓存任务ID对应的数据库文件
  15. self._init_tables()
  16. def _get_day_number(self, date=None):
  17. """获取日期编号 (YYYYMMDD 格式)"""
  18. if date is None:
  19. date = datetime.datetime.now()
  20. return date.strftime("%Y%m%d")
  21. def _get_db_key(self, username, day_number=None):
  22. """生成数据库连接缓存key"""
  23. if day_number is None:
  24. day_number = self._get_day_number()
  25. return f"{username}_{day_number}"
  26. def _get_db_path(self, username, day_number=None):
  27. """获取数据库文件路径"""
  28. if day_number is None:
  29. day_number = self._get_day_number()
  30. # 数据库文件存放目录
  31. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  32. base_dir = os.path.join(project_root, "data", "chat_history")
  33. os.makedirs(base_dir, exist_ok=True)
  34. # 数据库文件名格式: chat_history_{username}_{day_number}.db
  35. # 对用户名进行安全处理,避免特殊字符
  36. safe_username = re.sub(r'[\\/:*?"<>|\0]', "_", username)
  37. print(f"_get_db_path用户名: {safe_username}")
  38. db_filename = f"chat_history_{safe_username}_{day_number}.db"
  39. return os.path.join(base_dir, db_filename)
  40. def _get_db_connection(self, username, day_number=None):
  41. """获取数据库连接(按用户+日期分割)"""
  42. if day_number is None:
  43. day_number = self._get_day_number()
  44. db_key = self._get_db_key(username, day_number)
  45. # 如果连接不存在或需要重新连接,创建新连接
  46. if db_key not in self._db_connections:
  47. db_path = self._get_db_path(username, day_number)
  48. conn = sqlite3.connect(db_path, check_same_thread=False)
  49. conn.execute("PRAGMA journal_mode=WAL")
  50. self._db_connections[db_key] = conn
  51. # 初始化该数据库的表
  52. self._init_tables_for_connection(conn)
  53. return self._db_connections[db_key]
  54. def _init_tables_for_connection(self, conn):
  55. """为指定数据库连接初始化表结构"""
  56. cursor = conn.cursor()
  57. # 创建聊天任务表
  58. cursor.execute(
  59. """
  60. CREATE TABLE IF NOT EXISTS chat_tasks (
  61. task_id TEXT PRIMARY KEY,
  62. status TEXT NOT NULL, -- pending, processing, completed, failed
  63. request_data TEXT NOT NULL,
  64. response_data TEXT,
  65. error_message TEXT,
  66. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  67. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  68. thread_id TEXT,
  69. username TEXT
  70. )
  71. """
  72. )
  73. # 创建索引
  74. cursor.execute("CREATE INDEX IF NOT EXISTS idx_status ON chat_tasks(status)")
  75. cursor.execute(
  76. "CREATE INDEX IF NOT EXISTS idx_thread_id ON chat_tasks(thread_id)"
  77. )
  78. cursor.execute(
  79. "CREATE INDEX IF NOT EXISTS idx_username ON chat_tasks(username)"
  80. )
  81. cursor.execute(
  82. "CREATE INDEX IF NOT EXISTS idx_created_at ON chat_tasks(created_at)"
  83. )
  84. conn.commit()
  85. def _init_tables(self):
  86. """初始化所有数据库表(兼容性方法)"""
  87. # 为默认用户初始化表,确保系统启动时至少有一个数据库可用
  88. default_username = "default"
  89. conn = self._get_db_connection(default_username)
  90. self._init_tables_for_connection(conn)
  91. def create_task(self, request_data: Dict[str, Any]) -> str:
  92. """创建新的聊天任务"""
  93. task_id = str(uuid.uuid4())
  94. username = request_data.get("username", "default")
  95. print(f"创建任务: {task_id} (用户: {username})")
  96. conn = self._get_db_connection(username)
  97. cursor = conn.cursor()
  98. cursor.execute(
  99. """
  100. INSERT INTO chat_tasks
  101. (task_id, status, request_data, thread_id, username)
  102. VALUES (?, ?, ?, ?, ?)
  103. """,
  104. (
  105. task_id,
  106. "pending",
  107. json.dumps(request_data, ensure_ascii=False),
  108. request_data.get("thread_id", "default"),
  109. username,
  110. ),
  111. )
  112. conn.commit()
  113. # 缓存任务ID对应的数据库信息
  114. self._task_db_cache[task_id] = {
  115. "username": username,
  116. "day_number": self._get_day_number(),
  117. "db_file": os.path.basename(self._get_db_path(username)),
  118. }
  119. chat_logger.info(f"创建聊天任务: {task_id} (用户: {username})")
  120. return task_id
  121. def update_task_status(
  122. self,
  123. task_id: str,
  124. status: str,
  125. response_data: Dict[str, Any] = None,
  126. error_message: str = None,
  127. ):
  128. """更新任务状态"""
  129. # 首先检查缓存中是否有该任务对应的数据库信息
  130. if task_id in self._task_db_cache:
  131. cache_info = self._task_db_cache[task_id]
  132. username = cache_info["username"]
  133. day_number = cache_info["day_number"]
  134. # 尝试在缓存的数据库中更新
  135. db_path = self._get_db_path(username, day_number)
  136. if os.path.exists(db_path):
  137. try:
  138. self._update_task_in_db(
  139. username,
  140. task_id,
  141. status,
  142. response_data,
  143. error_message,
  144. day_number,
  145. )
  146. return
  147. except Exception as e:
  148. chat_logger.warning(
  149. f"使用缓存更新任务状态失败: {task_id} - {str(e)}"
  150. )
  151. # 缓存可能已失效,继续使用备用策略
  152. del self._task_db_cache[task_id]
  153. # 备用策略:如果缓存不存在或失效,使用原来的get_task方法查找
  154. task_info = self.get_task(task_id)
  155. if not task_info:
  156. chat_logger.error(f"更新任务状态失败: 任务不存在 - {task_id}")
  157. return
  158. username = task_info.get("username", "default")
  159. self._update_task_in_db(username, task_id, status, response_data, error_message)
  160. def _update_task_in_db(
  161. self,
  162. username,
  163. task_id,
  164. status,
  165. response_data=None,
  166. error_message=None,
  167. day_number=None,
  168. ):
  169. """在指定数据库中更新任务状态(内部方法)"""
  170. conn = self._get_db_connection(username, day_number)
  171. cursor = conn.cursor()
  172. cursor.execute(
  173. """
  174. UPDATE chat_tasks
  175. SET status = ?,
  176. response_data = ?,
  177. error_message = ?,
  178. updated_at = CURRENT_TIMESTAMP
  179. WHERE task_id = ?
  180. """,
  181. (
  182. status,
  183. (
  184. json.dumps(response_data, ensure_ascii=False)
  185. if response_data
  186. else None
  187. ),
  188. error_message,
  189. task_id,
  190. ),
  191. )
  192. conn.commit()
  193. chat_logger.info(f"更新任务状态: {task_id} -> {status} (用户: {username})")
  194. def _search_task_in_db(self, db_path, task_id):
  195. """在指定数据库中搜索任务"""
  196. try:
  197. conn = sqlite3.connect(db_path, check_same_thread=False)
  198. cursor = conn.cursor()
  199. cursor.execute(
  200. """
  201. SELECT task_id, status, request_data, response_data,
  202. error_message, created_at, updated_at, thread_id, username
  203. FROM chat_tasks
  204. WHERE task_id = ?
  205. """,
  206. (task_id,),
  207. )
  208. row = cursor.fetchone()
  209. conn.close()
  210. if row:
  211. return {
  212. "task_id": row[0],
  213. "status": row[1],
  214. "request_data": json.loads(row[2]) if row[2] else {},
  215. "response_data": json.loads(row[3]) if row[3] else None,
  216. "error_message": row[4],
  217. "created_at": row[5],
  218. "updated_at": row[6],
  219. "thread_id": row[7],
  220. "username": row[8],
  221. }
  222. except Exception as e:
  223. chat_logger.warning(
  224. f"搜索任务时数据库访问异常: {os.path.basename(db_path)} - {str(e)}"
  225. )
  226. return None
  227. def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
  228. """获取任务信息 - 智能查找策略"""
  229. # 1. 首先检查缓存中是否有该任务对应的数据库信息
  230. if task_id in self._task_db_cache:
  231. cache_info = self._task_db_cache[task_id]
  232. username = cache_info["username"]
  233. day_number = cache_info["day_number"]
  234. # 尝试在缓存的数据库中查找
  235. db_path = self._get_db_path(username, day_number)
  236. if os.path.exists(db_path):
  237. result = self._search_task_in_db(db_path, task_id)
  238. if result:
  239. return result
  240. else:
  241. # 如果缓存的文件不存在,清除缓存
  242. del self._task_db_cache[task_id]
  243. # 2. 如果没有缓存,尝试从请求上下文推断用户名
  244. # 这里可以扩展为从请求头或其他上下文获取用户名
  245. # 目前先使用默认策略
  246. # 3. 按优先级搜索策略
  247. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  248. base_dir = os.path.join(project_root, "data", "chat_history")
  249. if not os.path.exists(base_dir):
  250. return None
  251. # 策略1: 搜索当天所有用户的数据库(最可能的情况)
  252. today = self._get_day_number()
  253. for db_file in os.listdir(base_dir):
  254. if db_file.endswith(f"_{today}.db") and db_file.startswith("chat_history_"):
  255. db_path = os.path.join(base_dir, db_file)
  256. result = self._search_task_in_db(db_path, task_id)
  257. if result:
  258. # 找到后缓存数据库信息
  259. self._task_db_cache[task_id] = {
  260. "username": result["username"],
  261. "day_number": today,
  262. "db_file": db_file,
  263. }
  264. return result
  265. # 策略2: 搜索最近7天内的数据库(次优情况)
  266. for days_back in range(1, 8):
  267. search_date = datetime.datetime.now() - datetime.timedelta(days=days_back)
  268. search_day = self._get_day_number(search_date)
  269. for db_file in os.listdir(base_dir):
  270. if db_file.endswith(f"_{search_day}.db") and db_file.startswith(
  271. "chat_history_"
  272. ):
  273. db_path = os.path.join(base_dir, db_file)
  274. result = self._search_task_in_db(db_path, task_id)
  275. if result:
  276. # 找到后缓存数据库信息
  277. self._task_db_cache[task_id] = {
  278. "username": result["username"],
  279. "day_number": search_day,
  280. "db_file": db_file,
  281. }
  282. return result
  283. # 策略3: 最后搜索所有数据库(最坏情况)
  284. for db_file in os.listdir(base_dir):
  285. if db_file.endswith(".db") and db_file.startswith("chat_history_"):
  286. db_path = os.path.join(base_dir, db_file)
  287. result = self._search_task_in_db(db_path, task_id)
  288. if result:
  289. # 从文件名中提取用户名和日期
  290. parts = (
  291. db_file.replace("chat_history_", "")
  292. .replace(".db", "")
  293. .split("_")
  294. )
  295. if len(parts) >= 2:
  296. username_part = "_".join(parts[:-1]) # 用户名可能包含下划线
  297. day_number = parts[-1]
  298. self._task_db_cache[task_id] = {
  299. "username": result["username"],
  300. "day_number": day_number,
  301. "db_file": db_file,
  302. }
  303. return result
  304. return None
  305. def cleanup_old_tasks(self, max_days=7):
  306. """清理超过指定天数的旧任务数据"""
  307. cutoff_date = datetime.datetime.now() - datetime.timedelta(days=max_days)
  308. cutoff_day = self._get_day_number(cutoff_date)
  309. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  310. base_dir = os.path.join(project_root, "data", "chat_history")
  311. if not os.path.exists(base_dir):
  312. return
  313. cleaned_count = 0
  314. # 清理旧的数据库文件
  315. for db_file in os.listdir(base_dir):
  316. if db_file.endswith(".db") and db_file.startswith("chat_history_"):
  317. # 从文件名中提取日期部分
  318. # 格式: chat_history_{username}_{YYYYMMDD}.db
  319. parts = db_file.split("_")
  320. if len(parts) >= 3:
  321. try:
  322. file_day = parts[-1].replace(".db", "")
  323. if file_day < cutoff_day:
  324. db_path = os.path.join(base_dir, db_file)
  325. os.remove(db_path)
  326. cleaned_count += 1
  327. chat_logger.info(f"清理旧数据库文件: {db_file}")
  328. # 清理对应的缓存
  329. task_ids_to_remove = []
  330. for tid, cache_info in self._task_db_cache.items():
  331. if cache_info["db_file"] == db_file:
  332. task_ids_to_remove.append(tid)
  333. for tid in task_ids_to_remove:
  334. del self._task_db_cache[tid]
  335. except (ValueError, IndexError):
  336. continue
  337. chat_logger.info(f"清理完成: 删除了 {cleaned_count} 个旧数据库文件")
  338. def close(self):
  339. """关闭所有数据库连接"""
  340. for db_key, conn in self._db_connections.items():
  341. try:
  342. conn.close()
  343. except Exception as e:
  344. chat_logger.warning(f"关闭数据库连接失败: {db_key} - {str(e)}")
  345. self._db_connections.clear()
  346. self._task_db_cache.clear()
  347. chat_logger.info("所有数据库连接已关闭")
# Module-level shared instance; constructing it runs ChatResultManager.__init__
# when this module is imported.
chat_result_manager = ChatResultManager()