Bläddra i källkod

龙嘉AI助手项目第一次提交代码

longjoedyy 1 månad sedan
incheckning
70bc4ae999

+ 1 - 0
.env.development.example

@@ -0,0 +1 @@
+DEEPSEEK_API_KEY=你的DEEPSEEK_API_KEY(明文,非加密),开发环境使用

+ 1 - 0
.env.example

@@ -0,0 +1 @@
+ENCRYPTED_DEEPSEEK_KEY=加密后的DEEPSEEK_API_KEY,生产环境使用

+ 49 - 0
.gitignore

@@ -0,0 +1,49 @@
+# 环境配置文件
+.env
+.env.development
+.registration
+
+# 个人调试文件
+terminal_demo.py
+
+# Python编译文件
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+.pytest_cache/
+
+# 编译生成的文件
+*.pyd
+*.pyc
+r.pyc
+*.pyo
+*.pyd
+
+# 构建目录
+build/
+dist/
+*.egg-info/
+
+# 项目特定目录
+chat_logs/
+data/
+tests/
+cython_build/
+
+# 编辑器文件
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# 操作系统文件
+.DS_Store
+Thumbs.db
+
+# 其他临时文件
+*.log
+*.tmp
+*.temp

+ 1 - 0
.registration.example

@@ -0,0 +1 @@
+REGISTRATION_CODE=注册码

+ 225 - 0
README.md

@@ -0,0 +1,225 @@
+# 龙嘉软件AI助手项目
+
+## 项目简介
+
+龙嘉软件AI助手是一个基于FastAPI和LangChain构建的智能对话系统,提供聊天、OCR识别、文档处理等AI功能。项目采用模块化设计,支持多用户并发处理,具备完善的工具管理和缓存机制。
+
+## 项目结构
+
+```
+LongjoeAgent/
+├── api/                    # API接口层
+│   ├── models.py          # 数据模型定义
+│   └── routes.py          # 路由定义
+├── config/                # 配置文件
+│   ├── template_config.json   # 单据识别模板提示词配置
+│   ├── tool_config.json   # 工具提示词配置
+├── core/                  # 核心业务逻辑
+│   ├── agent.py           # LangChain Agent创建
+│   ├── agent_manager.py   # Agent管理器
+│   ├── chat_service.py    # 聊天服务
+│   ├── document_processor/ # 文档处理模块
+│   ├── ocr_service.py     # OCR服务
+│   └── lifespan_manager.py # 生命周期管理
+├── tools/                 # 工具模块
+│   ├── base_tools.py      # 工具类公用方法
+│   ├── knowledge_tools.py # 知识库工具
+│   ├── sale_tools.py      # 销售工具
+│   ├── ware_tools.py      # 库存工具
+│   └── tool_factory.py    # 工具工厂
+├── middlewares/           # 中间件
+├── utils/                 # 工具类
+├── chat_logs/            # 聊天日志
+├── app.py                # 主应用入口
+├── requirements.txt      # 依赖包
+├── cythonize.py          # 编译工具
+└── setup.py             # 编译配置
+```
+
+## 主要功能
+
+### 1. 智能聊天
+- 支持多用户并发对话
+- 基于LangChain的智能Agent
+- 工具调用和思维链展示
+- 对话历史管理
+
+### 2. OCR识别
+- 支持图片Base64格式识别
+- 集成PaddleOCR服务
+- 发票、单据等文档识别
+- 结构化数据提取
+
+### 3. 文档处理
+- 多类型文档模板支持
+- 自动化的单据创建
+- 文档内容解析和结构化
+
+### 4. 工具管理
+- 动态工具发现和加载
+- 知识库查询工具
+- 销售数据获取工具
+- 可扩展的工具框架
+
+## 技术栈
+
+- **后端框架**: FastAPI
+- **AI框架**: LangChain, LangGraph
+- **OCR服务**: PaddleOCR
+- **缓存管理**: 本地缓存
+- **日志系统**: 自定义日志记录
+- **部署**: Uvicorn + 多进程
+
+## 快速开始
+
+### 环境要求
+
+- Python 3.8+
+- 安装依赖包
+
+### 安装步骤
+0. 安装python 3.8+
+```https://www.python.org/downloads/```
+
+1. 克隆项目
+
+
+2. 安装依赖
+```bash
+pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
+```
+
+3. 配置DEEPSEEK_API_KEY
+参考`.env.development.example`创建 `.env.development` 文件并配置DEEPSEEK_API_KEY
+
+
+4. 启动服务
+```bash
+python app.py
+```
+
+5. 服务启动后访问
+健康检查: http://localhost:8888/health, 复制"device_id"的值,向耀申请注册码
+
+6. 配置注册码
+参考`.registration.example`创建 `.registration` 文件并配置REGISTRATION_CODE
+
+## API接口
+
+### 主要接口
+
+#### 1. 聊天接口
+- **路径**: `POST /chat`
+- **功能**: 处理用户消息并返回AI回复
+- **请求体**:
+```json
+{
+  "message": "用户消息",
+  "thread_id": "会话ID",
+  "username": "用户名",
+  "backend_url": "后端URL",
+  "token": "认证令牌"
+}
+```
+
+#### 2. OCR识别接口
+- **路径**: `POST /ocr`
+- **功能**: 识别图片中的文字
+- **请求体**:
+```json
+{
+  "image": "Base64图片数据",
+  "type": "图片类型"
+}
+```
+
+#### 3. 单据创建接口
+- **路径**: `POST /message_create_bill`
+- **功能**: 通过文本消息辅助创建单据
+- **请求体**:
+```json
+{
+  "message": "单据描述",
+  "document_type": "单据类型"
+}
+```
+
+
+## 工具配置
+项目目前支持以下工具:
+- **知识库工具**: 查询知识库内容
+- **销售工具**: 获取销售数据
+- **库存工具**: 通过工具工厂动态加载
+
+## 开发指南
+
+### 添加新工具
+
+1. 在 `tools/` 目录下创建新的工具文件,以 _tools.py 结尾,通常一个文件对应一个模块的工具
+2. 使用 `@tool` 装饰器标记工具函数,或使用 `tool` 类定义工具(参考ware_tools.py)
+3. 工具函数需要包含 `name` 和 `description` 属性
+4. 工具工厂会自动发现并加载新工具
+5. 工具的提示词可以通过json配置,在`config\tool_config.json`中定义,重新启动服务后生效
+
+### 扩展API接口
+
+1. 在 `api/routes.py` 中添加新的路由
+2. 在 `api/models.py` 中定义数据模型
+3. 在相应的服务模块中实现业务逻辑
+
+## 部署说明
+
+### 生产环境部署
+1. 生成依赖包 requirements.txt
+```bash
+pipreqs . --encoding=utf8 --force
+```
+
+1. 使用setup.py编译项目
+```bash
+python cythonize.py
+```
+
+1. 使用Uvicorn启动服务
+```bash
+uvicorn app:app --host 0.0.0.0 --port 8888 --workers 4
+```
+
+### 性能优化建议
+
+- 调整线程池大小(在chat_service.py中配置)
+- 优化缓存策略(在agent_manager.py中配置)
+- 监控内存使用和响应时间
+
+## 故障排除
+
+### 常见问题
+
+1. **工具加载失败**
+   - 检查工具函数是否正确定义
+   - 确认工具函数有name和description属性
+
+2. **OCR识别失败**
+   - 检查OCR服务连接
+   - 验证图片格式和大小
+
+3. **内存泄漏**
+   - 定期清理缓存
+   - 监控内存使用情况
+
+### 日志查看
+
+日志文件位于 `chat_logs/` 目录,包含详细的运行信息。
+
+## 许可证
+
+本项目仅供龙嘉软件内部使用。
+
+## 联系方式
+
+如有问题请联系项目维护团队。
+
+---
+
+*最后更新: 2026年1月*
+        

+ 46 - 0
api/models.py

@@ -0,0 +1,46 @@
+from pydantic import BaseModel
+from typing import Optional, List, Dict
+
+
+class ChatRequest(BaseModel):
+    message: str
+    thread_id: str = "default"
+    username: str = "default"
+    backend_url: str = ""
+    token: str = ""
+    include_thoughts: bool = False
+    include_tool_calls: bool = False
+
+
+class MessageModel(BaseModel):
+    type: str
+    content: str
+    tool_calls: Optional[List[Dict]] = None
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+
+
+class ChatResponse(BaseModel):
+    final_answer: str
+    all_ai_messages: List[MessageModel]
+    all_messages: List[MessageModel]
+    tool_calls: List[Dict]
+    thread_id: str
+    user_identifier: str
+    backend_config: Dict
+    success: bool
+    error: Optional[str] = None
+
+
+class OCRRequest(BaseModel):
+    """图片处理请求"""
+
+    image: str
+    type: str
+
+
+class MessageCreateBill(BaseModel):
+    """创建单据请求"""
+
+    message: str
+    document_type: str = None

+ 314 - 0
api/routes.py

@@ -0,0 +1,314 @@
+import base64
+from datetime import datetime
+from fastapi import APIRouter, HTTPException
+
+from utils.device_id import get_device_id
+from .models import ChatRequest, ChatResponse, OCRRequest, MessageCreateBill
+from core.chat_service import chat_service
+from core.agent_manager import agent_manager
+from utils.logger import chat_logger
+from tools.tool_factory import get_all_tools
+import time
+from utils.registration_manager import registration_manager
+from core.ocr_service import PaddleOCRService
+from core.document_processor.document_service import DocumentProcessingService
+
+# 初始化服务
+ocr_service = PaddleOCRService(
+    api_url="https://a8l0g1qda8zd48nb.aistudio-app.com/ocr",
+    token="f97d214abf87d5ea3c156e21257732a3b19661cb",
+)
+doc_service = DocumentProcessingService(ocr_service=ocr_service)
+
+
+router = APIRouter()
+
+
+@router.post("/chat", response_model=ChatResponse)
+async def chat_endpoint(request: ChatRequest):
+    """聊天接口"""
+    try:
+        result = await chat_service.process_chat_request(request.model_dump())
+        return ChatResponse(**result)
+    except Exception as e:
+        chat_logger.error(f"API处理失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
+
+
+@router.post("/ocr")
+async def orc(request: OCRRequest):
+    """OCR接口"""
+    try:
+        chat_logger.info(f"开始进行图片识别")
+
+        # 1. 解码Base64图片
+        try:
+            if "," in request.image:
+                # 去掉 data:image/xxx;base64, 前缀
+                base64_str = request.image.split(",", 1)[1]
+            else:
+                base64_str = request.image
+
+            image_bytes = base64.b64decode(base64_str)
+
+        except Exception as e:
+            chat_logger.error(f"图片解码失败: {e}")
+            raise HTTPException(400, f"图片格式错误: {str(e)}")
+
+        # 2. OCR识别
+        result = await doc_service.pure_ocr(image_bytes=image_bytes)
+
+        # 3. 返回结果
+        return result
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        chat_logger.error(f"处理失败: {e}")
+        raise HTTPException(500, f"处理失败: {str(e)}")
+
+
+@router.post("/ocr_create_bill")
+async def ocr_create_bill_endpoint(request: OCRRequest):
+    """
+    处理单张图片
+    请求格式:
+    {
+        "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
+        "type": "invoice",
+    }
+    """
+    try:
+        chat_logger.info(f"开始处理 {request.type} 单据")
+
+        # 1. 解码Base64图片
+        try:
+            if "," in request.image:
+                # 去掉 data:image/xxx;base64, 前缀
+                base64_str = request.image.split(",", 1)[1]
+            else:
+                base64_str = request.image
+
+            image_bytes = base64.b64decode(base64_str)
+
+        except Exception as e:
+            chat_logger.error(f"图片解码失败: {e}")
+            raise HTTPException(400, f"图片格式错误: {str(e)}")
+
+        # 2. 处理单据
+        result = await doc_service.ocr_create_bill(
+            image_bytes=image_bytes, document_type=request.type
+        )
+
+        # 3. 返回结果
+        return {
+            "success": True,
+            "type": request.type,
+            "text": result.get("ocr_text", ""),  # 识别出的文本
+            "data": result.get("parsed_data", {}),  # 结构化数据
+            "timestamp": datetime.now().isoformat(),
+        }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        chat_logger.error(f"处理失败: {e}")
+        raise HTTPException(500, f"处理失败: {str(e)}")
+
+
+@router.post("/message_create_bill")
+async def message_create_bill_endpoint(request: MessageCreateBill):
+    """
+    通过文本消息辅助建立单据
+    请求格式:
+    {
+        "message": "这是一条单据描述",
+        "document_type": "invoice",
+    }
+    """
+    try:
+        result = await doc_service.message_create_bill(
+            message=request.message, document_type=request.document_type
+        )
+        return result
+    except Exception as e:
+        chat_logger.error(f"处理失败: {e}")
+        raise HTTPException(500, f"处理失败: {str(e)}")
+
+
+@router.get("/cache/status")
+async def cache_status():
+    """查看缓存状态 - 修复版本"""
+    try:
+        cache_status = await agent_manager.get_cache_status()
+
+        # ✅ 确保返回完整信息
+        return {
+            "success": True,
+            "cache_size": cache_status.get("cache_size", 0),
+            "cache_expiry_seconds": cache_status.get("cache_expiry_seconds", 3600),
+            "cache_entries_count": len(cache_status.get("cache_entries", [])),
+            "cache_entries": cache_status.get("cache_entries", []),
+            "timestamp": time.time(),
+        }
+    except Exception as e:
+        chat_logger.error(f"获取缓存状态失败: {str(e)}")
+        return {
+            "success": False,
+            "error": str(e),
+            "cache_size": 0,
+            "cache_entries_count": 0,
+            "cache_entries": [],
+        }
+
+
+@router.delete("/cache/clear")
+async def clear_cache():
+    """清空缓存"""
+    try:
+        # ✅ 使用异步版本
+        cleared = await agent_manager.clear_cache()
+        chat_logger.info(f"清空agent缓存, 清理数量={cleared}")
+        return {
+            "cleared_entries": cleared,
+            "message": "缓存已清空",
+            "status": "success",
+        }
+    except Exception as e:
+        chat_logger.error(f"清空缓存失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"清空缓存失败: {str(e)}")
+
+
+@router.get("/health")
+async def health_check():
+    """健康检查"""
+    try:
+        # ✅ 使用异步版本
+
+        registration_status = await registration_manager.check_registration()
+
+        return {
+            "status": "healthy",
+            "service": "龙嘉软件AI助手API",
+            "registration_status": (
+                "已注册" if registration_status else "未注册或注册过期"
+            ),
+            "device_id": get_device_id(),
+            "timestamp": time.time(),
+        }
+    except Exception as e:
+        return {
+            "status": "unhealthy",
+            "service": "龙嘉软件AI助手API",
+            "error": str(e),
+            "timestamp": time.time(),
+        }
+
+
+@router.get("/tools/status")
+async def tools_status():
+    """查看工具状态"""
+    try:
+        tools = get_all_tools()
+        tool_info = []
+
+        for i, tool in enumerate(tools):
+            tool_info.append(
+                {
+                    "index": i + 1,
+                    "name": getattr(tool, "name", "unknown"),
+                    "description": getattr(tool, "description", "unknown")[:100]
+                    + "...",
+                }
+            )
+
+        return {"total_tools": len(tools), "tools": tool_info, "status": "success"}
+    except Exception as e:
+        chat_logger.error(f"获取工具状态失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"获取工具状态失败: {str(e)}")
+
+
+@router.post("/cache/initialize")
+async def initialize_cache():
+    """初始化缓存系统"""
+    try:
+        # ✅ 使用异步版本
+        await agent_manager.initialize()
+        chat_logger.info("缓存系统初始化完成")
+        return {"status": "success", "message": "缓存系统初始化完成"}
+    except Exception as e:
+        chat_logger.error(f"初始化缓存失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"初始化缓存失败: {str(e)}")
+
+
+@router.post("/cache/shutdown")
+async def shutdown_cache():
+    """关闭缓存系统"""
+    try:
+        # ✅ 使用异步版本
+        cleared = await agent_manager.shutdown()
+        chat_logger.info(f"缓存系统已关闭,清理了 {cleared} 个实例")
+        return {
+            "cleared_entries": cleared,
+            "message": "缓存系统已关闭",
+            "status": "success",
+        }
+    except Exception as e:
+        chat_logger.error(f"关闭缓存失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"关闭缓存失败: {str(e)}")
+
+
+@router.post("/admin/refresh-registration")
+async def refresh_registration():
+    """手动刷新注册状态(管理员用)"""
+    registration_manager.force_refresh()
+    return {"status": "success", "message": "注册状态已刷新"}
+
+
+@router.get("/admin/registration-status")
+async def get_registration_status():
+    """获取注册状态(管理员用)"""
+    status = await registration_manager.check_registration()
+    status_info = registration_manager.get_status()
+    return {"is_registered": status, "status_info": status_info}
+
+
+@router.get("/")
+async def root():
+    registration_status = await registration_manager.check_registration()
+
+    base_info = {
+        "service": "龙嘉软件AI助手API",
+        "version": "1.0.0",
+        "registration_status": "active" if registration_status else "expired",
+        "device_id": get_device_id(),
+    }
+
+    if registration_status:
+        base_info.update(
+            {
+                "endpoints": {
+                    "POST /chat": "聊天",
+                    "GET /health": "健康检查",
+                    "GET /cache/status": "查看agent缓存状态",
+                    "DELETE /cache/clear": "清空agent缓存",
+                    "GET /": "API信息",
+                }
+            }
+        )
+    else:
+        base_info.update(
+            {
+                "message": "⚠️ 服务注册已过期,部分功能受限",
+                "available_endpoints": {
+                    "GET /health": "健康检查",
+                    "GET /registration/status": "查看注册状态",
+                    "GET /service/info": "服务完整信息",
+                    "GET /": "API信息",
+                },
+                "restricted_endpoints": {"POST /chat": "AI聊天功能(需续费)"},
+                "support_contact": "请联系管理员续费服务",
+            }
+        )
+
+    return base_info

+ 27 - 0
app.py

@@ -0,0 +1,27 @@
+from fastapi import FastAPI
+from api.routes import router
+from core.lifespan_manager import create_lifespan
+from middlewares.registration_middleware import registration_check_middleware
+
+app = FastAPI(
+    title="龙嘉软件AI助手API",
+    description="龙嘉软件公司的AI助手",
+    version="1.0.0",
+    lifespan=create_lifespan(),
+)
+# 添加中间件
+app.middleware("http")(registration_check_middleware)
+
+app.include_router(router)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8888)
+    # uvicorn.run(
+    #     "app:app",  # 导入字符串格式
+    #     host="0.0.0.0",
+    #     port=8888,
+    #     workers=4,
+    #     reload=False,
+    # )

+ 108 - 0
check_deploy.py

@@ -0,0 +1,108 @@
+"""
+部署检查脚本
+"""
+
+import sys
+from pathlib import Path
+import subprocess
+
+
+def check_deployment():
+    """检查部署包"""
+    deploy_dir = Path("build") / "deploy"
+
+    if not deploy_dir.exists():
+        print("❌ 部署包不存在,请先运行编译: python cythonize.py")
+        return False
+
+    print("🔍 检查部署包...")
+
+    # 检查必要文件
+    required_files = {
+        "app.py": "主应用文件",
+        "api/": "API模块",
+        "core/": "核心模块",
+        "tools/": "工具模块",
+    }
+
+    all_ok = True
+    for file_name, description in required_files.items():
+        file_path = deploy_dir / file_name
+        if file_path.exists():
+            print(f"  ✓ {file_name:20} {description}")
+        else:
+            print(f"  ✗ {file_name:20} {description} (缺失)")
+            all_ok = False
+
+    # 检查编译文件
+    compiled_ext = ".pyd" if sys.platform == "win32" else ".so"
+    compiled_files = list(deploy_dir.rglob(f"*{compiled_ext}"))
+    print(f"  ✓ 编译文件数量: {len(compiled_files)} 个 {compiled_ext} 文件")
+
+    # 检查依赖
+    req_file = deploy_dir / "requirements.txt"
+    if req_file.exists():
+        print(f"  ✓ requirements.txt 存在")
+
+        # 检查Python版本
+        try:
+            result = subprocess.run(
+                [sys.executable, "--version"], capture_output=True, text=True
+            )
+            print(f"  Python版本: {result.stdout.strip()}")
+        except:
+            print("  ⚠️  无法获取Python版本")
+    else:
+        print("  ⚠️  requirements.txt 不存在")
+        all_ok = False
+
+    return all_ok
+
+
+def create_minimal_deploy():
+    """创建最小化部署包"""
+    import zipfile
+
+    deploy_dir = Path("build") / "deploy"
+    if not deploy_dir.exists():
+        print("❌ 部署包不存在")
+        return
+
+    # 创建zip包
+    zip_path = Path("build") / "longjoeagent_deploy.zip"
+
+    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
+        for file_path in deploy_dir.rglob("*"):
+            if file_path.is_file():
+                arcname = file_path.relative_to(deploy_dir)
+                zipf.write(file_path, arcname)
+
+    print(f"✅ 创建部署包: {zip_path}")
+    print(f"   大小: {zip_path.stat().st_size / 1024 / 1024:.2f} MB")
+
+
+def main():
+    """主函数"""
+    print("=" * 50)
+    print("      LongjoeAgent 部署检查工具")
+    print("=" * 50)
+
+    if check_deployment():
+        print("\n✅ 部署包检查通过")
+
+        # 询问是否创建zip包
+        if sys.platform == "win32":
+            response = input("\n是否创建zip部署包? (y/n): ")
+            if response.lower() == "y":
+                create_minimal_deploy()
+    else:
+        print("\n❌ 部署包检查失败")
+
+    print("\n📋 部署命令示例:")
+    print("  scp -r build/deploy/ user@server:/opt/longjoeagent/")
+    print("  cd /opt/longjoeagent && pip install -r requirements.txt")
+    print("  python app.py")
+
+
+if __name__ == "__main__":
+    main()

+ 92 - 0
config/settings.py

@@ -0,0 +1,92 @@
+import os
+from pathlib import Path
+
+
+# 基础配置
+class Settings:
+    def __init__(self):
+        # 1. 检测环境
+        self.is_development = Path(".env.development").exists()
+        self.env = "development" if self.is_development else "production"
+        print(f"🎯 当前环境: {self.env}")
+
+        # 2. 加载配置文件
+        self._load_env_file()
+
+        # API配置
+        self.DEEPSEEK_API_KEY = self._get_api_key()
+        self.DEEPSEEK_BASE_URL = os.getenv(
+            "DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1"
+        )
+
+        # KMS配置
+        self.KMS_LIST_URL = os.getenv(
+            "KMS_LIST_URL", "http://kms.longjoe.com:82/KMS/api/common/DocListAi"
+        )
+        self.KMS_VIEW_URL = os.getenv(
+            "KMS_VIEW_URL", "http://kms.longjoe.com:82/KMS/api/common/DocViewAi"
+        )
+
+        # 其他配置
+        self.LLM_MODEL = "deepseek-chat"
+        self.LLM_TEMPERATURE = 0.7
+        self.LLM_MAX_TOKENS = 2048
+
+    def _load_env_file(self):
+        """加载对应的.env文件"""
+        from dotenv import load_dotenv
+
+        env_file = ".env.development" if self.is_development else ".env"
+
+        if Path(env_file).exists():
+            load_dotenv(env_file, override=True)
+            print(f"📁 加载配置文件: {env_file}")
+        else:
+            print(f"⚠️  配置文件 {env_file} 不存在")
+
+    def _get_api_key(self) -> str:
+        """根据环境获取API Key"""
+        if self.is_development:
+            # 开发环境:使用明文
+            key = os.getenv("DEEPSEEK_API_KEY")
+            if not key:
+                print("\n⚠️  开发环境提示:")
+                print("   请在 .env.development 中添加:")
+                print("   DEEPSEEK_API_KEY=sk-your-key-here")
+                print("=" * 40)
+            return key or ""
+        else:
+            # 生产环境:必须使用加密
+            encrypted_key = os.getenv("ENCRYPTED_DEEPSEEK_KEY")
+            if not encrypted_key:
+                raise ValueError("生产环境需要 ENCRYPTED_DEEPSEEK_KEY")
+
+            master_key = "ialwayslovelongjoe"
+
+            return self._decrypt_key(encrypted_key, master_key)
+
+    def _decrypt_key(self, encrypted_key: str, master_key: str) -> str:
+        """解密API Key"""
+        # 简单版AES解密
+        from Crypto.Cipher import AES
+        from Crypto.Util.Padding import unpad
+        import base64
+        import hashlib
+
+        try:
+            # 生成密钥
+            aes_key = hashlib.sha256(master_key.encode()).digest()
+            iv = hashlib.md5(master_key.encode()).digest()
+
+            # 解密
+            cipher = AES.new(aes_key, AES.MODE_CBC, iv)
+            ct = base64.b64decode(encrypted_key)
+            pt = unpad(cipher.decrypt(ct), AES.block_size)
+
+            return pt.decode("utf-8")
+        except Exception as e:
+            raise ValueError(f"解密失败: {e}")
+
+
+# 创建全局实例
+settings = Settings()

+ 48 - 0
config/template_config.json

@@ -0,0 +1,48 @@
+{
+    "template_extensions": {
+        "cusamt": {
+            "field_guidance": {
+                "cusname": [
+                    "一般是付款人字样旁边的户名,如果是农商行的回单,在左上角"
+                ],
+                "viewdate": [
+                    "日期格式可能为:YYYY-MM-DD、YYYY/MM/DD、YYYY年MM月DD日"
+                ],
+                "cusamt": [
+                    "金额:注意区分大小写:大写金额和小写金额都要识别"
+                ],
+                "accname": [
+                    "收款账户名称:格式可能: 工行, 广发行"
+                ],
+                "acccode": [
+                    "收款帐号:格式可能:6228****1234"
+                ],
+                "dscrp": [
+                    "备注:附言等信息,已在其他信息明确的无需重复"
+                ]
+            },
+            "additional_rules": "如果识别到农商行回单,付款人信息通常在左上角区域"
+        },
+        "saletask": {
+            "field_guidance": {
+                "cusname": [],
+                "taskdate": [],
+                "requiredate": [],
+                "banktype": [],
+                "relcode": [],
+                "otheramt": [],
+                "damt": [],
+                "cus_tele": [],
+                "rel_rep": [],
+                "dscrp": [],
+                "freight": [],
+                "mtrlname": [],
+                "unit": [],
+                "saleqty": [],
+                "enprice": [],
+                "rebate": []
+            },
+            "additional_rules": ""
+        }
+    }
+}

+ 50 - 0
config/template_config_manager.py

@@ -0,0 +1,50 @@
+import json
+import os
+from typing import Dict, Any, List
+
+
+class TemplateConfigManager:
+    def __init__(self, config_path: str = "config/template_config.json"):
+        self.config_path = config_path
+        self.config = self._load_config()
+
+    def _load_config(self) -> Dict[str, Any]:
+        """加载配置文件"""
+        if os.path.exists(self.config_path):
+            try:
+                with open(self.config_path, "r", encoding="utf-8") as f:
+                    return json.load(f)
+            except Exception as e:
+                print(f"加载配置文件失败: {e}")
+                return {"template_extensions": {}}
+        return {"template_extensions": {}}
+
+    def get_template_config(self, template_name: str) -> Dict[str, Any]:
+        """获取指定模板的配置,并过滤空白值"""
+        template_config = self.config.get("template_extensions", {}).get(
+            template_name, {}
+        )
+        return self._filter_empty_values(template_config)
+
+    def _filter_empty_values(self, config: Dict[str, Any]) -> Dict[str, Any]:
+        """过滤掉空白值(空列表、空字符串等)"""
+        filtered = {}
+
+        for key, value in config.items():
+            if isinstance(value, dict):
+                # 递归处理嵌套字典
+                filtered_subdict = self._filter_empty_values(value)
+                if filtered_subdict:  # 只添加非空的子字典
+                    filtered[key] = filtered_subdict
+            elif isinstance(value, list):
+                # 过滤空列表和非空字符串元素
+                filtered_list = [item for item in value if item and str(item).strip()]
+                if filtered_list:  # 只添加非空列表
+                    filtered[key] = filtered_list
+            elif isinstance(value, str) and value.strip():
+                # 只添加非空字符串
+                filtered[key] = value
+            elif value:  # 其他非空值(数字、布尔值等)
+                filtered[key] = value
+
+        return filtered

+ 38 - 0
config/tool_config.json

@@ -0,0 +1,38 @@
+{
+    "get_mtrlware_data": {
+        "基础描述": "获取指定物料的库存信息",
+        "功能说明": "从仓库管理系统中查询物料的实时库存数据,包括可用数量、已分配数量、在途数量等详细信息",
+        "入参说明": {
+            "backend_url": "后端API地址",
+            "token": "用户认证令牌,用于身份验证",
+            "mtrlname": "物料名称,支持模糊匹配"
+        },
+        "返回值说明": {
+            "格式": "一个包含物料库存数据的字符串",
+            "字段含义": "mtrlcode:物料编码, mtrlname:物料名称, storagename:仓库名称, noallocqty:库存数量, unit:单位, noauditingqty:已开单数量, notauditnoallocqty:未开单数量, pzinfo:配置信息, buydays:采购周期天数"
+        },
+        "输出格式要求": [
+            "以自然语言描述形式输出,不要使用表格",
+            "重复信息要总结归纳,精简显示"
+        ],
+        "使用示例": "用户输入:'查看铜管的库存' -> 系统调用此工具获取铜管的库存信息"
+    },
+    "get_sale_amt": {
+        "基础描述": "获取指定时间范围的销售金额,按月汇总",
+        "入参说明": {
+            "backend_url": "后端API地址",
+            "token": "认证令牌",
+            "funtion_name": "函数名称; get_sale_amt_by_month:按月汇总销售额; get_sale_amt_by_day:按天汇总销售额; get_sale_amt_by_produce:产品销售额; get_sale_amt_by_cus:客户销售额",
+            "firstdate": "开始日期,格式YYYY-MM-DD",
+            "lastdate": "结束日期,格式YYYY-MM-DD 23:59:59"
+        },
+        "返回值说明": {
+            "格式": "一个包含销售金额的字符串"
+        },
+        "输出格式要求": [
+            "以自然语言描述形式输出,不要使用表格",
+            "重复信息要总结归纳,精简显示"
+        ],
+        "使用示例": "用户输入:'查看2023年1月1日至2023年12月31日的销售金额' -> 系统调用此工具获取2023年1月至12月的销售金额"
+    }
+}

+ 292 - 0
core/agent.py

@@ -0,0 +1,292 @@
+import os
+import dotenv
+import datetime
+from pathlib import Path
+
+from langchain.agents import create_agent, AgentState
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
+from tools.tool_factory import get_all_tools
+from langchain_core.runnables import RunnableConfig
+from langchain.agents.middleware import before_model
+from langgraph.runtime import Runtime
+from typing import Any, List, Sequence
+from langchain.messages import RemoveMessage
+from langgraph.graph.message import REMOVE_ALL_MESSAGES
+import sqlite3
+from config.settings import settings
+
+dotenv.load_dotenv()
+
+
+def create_system_prompt(
+    backend_url: str = "", token: str = "", username: str = "default"
+) -> str:
+    """
+    创建动态的system_prompt,支持参数化配置
+
+    Args:
+        backend_url: 后端API地址
+        token: 访问后端的认证令牌
+        username: 用户名
+    """
+    # 判断token状态
+    if token:
+        token_status = "已配置有效的认证令牌,可以调用后端API获取用户数据"
+    else:
+        token_status = "未提供认证令牌,后端API调用可能受限"
+
+    # 判断backend_url状态
+    if backend_url:
+        backend_status = f"已配置后端地址: {backend_url}"
+    else:
+        backend_status = "未配置后端地址,只能访问知识库"
+
+    system_prompt = f"""你是属于龙嘉软件公司的AI助手,名字叫小龙。
+
+# 当前会话信息
+- 当前用户: {username}
+- 后端服务状态: {backend_status}
+- 认证状态: {token_status}
+现在时间是{datetime.datetime.now().isoformat()}
+
+# 核心能力
+你可为客户提供ERP问题的解决方案,也可回答与龙嘉软件相关的问题。软件面向全球客户,你需按用户提问的语言回答。
+
+# 工作流程
+1. 分析用户问题的意图,提取关键词
+2. 根据意图及关键词调用相应工具
+3. 可以访问知识库工具,也可以调用后端API获取数据
+
+# 后端API使用指南
+{"- 当用户需要查询个人数据、订单信息、账户状态时,可使用后端API" if backend_url and token else "- 由于缺少认证信息,暂时无法调用后端API"}
+{"- 后端地址: " + backend_url if backend_url else ""}
+{"- API调用会自动包含用户的认证令牌: " + token if token else ""}
+
+# 知识库搜索规则
+- 判断问题所属模块(销售、采购、生产、财务、仓储、权限等)并纳入关键词
+- 文章匹配要精准,例如"销售订单新建权限",拆分为:"销售订单"、"新建"、"权限"
+- 避免使用"的"、"地"、"得"、"了"、"在"等无意义词汇
+- 关键词可以多个,要判断问题属于哪个模块并将其纳入关键字
+- 如果匹配文章太少(少于3篇),尝试以下方法:
+  a) 变更关键字(同义词、近义词)
+  b) 把关键字拆得更细
+  c) 扩大搜索范围(减少关键词数量)
+  d) 重新搜索
+- 获取到文章列表后,用工具获取文章内容然后回答用户问题
+
+# 回答策略
+- 优先使用知识库中的准确信息
+- 如果知识库中有相关文章,结合文章内容进行回答
+- 如果需要实时数据且认证有效,可调用后端API
+- 如果找不到对应知识库文章,向客户说明:"我正在学习这个问题的解决方案,很快就能正式为您服务"
+- 如果用户的问题需要后端数据但认证无效,提示:"查看个人数据需要登录验证,请确保已提供正确的访问令牌"
+
+# 注意事项
+- 保护用户隐私,不在回复中暴露敏感信息
+- 如果API调用失败,提供友好的错误信息
+- 保持回答的专业性和准确性
+- 对于不确定的问题,可以建议用户联系客服或技术支持
+"""
+    # print(system_prompt)
+    return system_prompt
+
+
+def get_day_number(date=None):
+    """获取日期编号 (YYYYMMDD 格式)"""
+    if date is None:
+        date = datetime.datetime.now()
+    return date.strftime("%Y%m%d")  # 格式: 20251229
+
+
+def get_sqlite_checkpointer():
+    """创建按天分割的SQLite检查点保存器"""
+    try:
+        from langgraph.checkpoint.sqlite import SqliteSaver
+
+        # 获取当前日期编号
+        current_day = get_day_number()
+
+        # 数据库文件存放目录
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        base_dir = os.path.join(project_root, "data", "checkpoints")
+        os.makedirs(base_dir, exist_ok=True)
+
+        # 数据库文件名格式: checkpoints_20251229.db
+        db_filename = f"checkpoints_{current_day}.db"
+        db_path = os.path.join(base_dir, db_filename)
+        # checkpointer = SqliteSaver.from_conn_string(db_path)
+
+        conn = sqlite3.connect(db_path, check_same_thread=False)
+        conn.execute("PRAGMA wal_autocheckpoint=500")  # 2MB 就提交
+        conn.execute("PRAGMA journal_size_limit=52428800")  # 最大 50MB
+        checkpointer = SqliteSaver(conn)
+
+        return checkpointer
+
+    except Exception as e:
+        print(f"❌❌ 创建 SQLite 检查器失败: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+        # 回退到内存保存器
+        from langgraph.checkpoint.memory import InMemorySaver
+
+        print("⚠️ 使用 InMemorySaver 作为回退")
+        return InMemorySaver()
+
+
+def cleanup_old_checkpoints(max_days=7):
+    """清理超过指定天数的旧检查点文件(可选功能)"""
+    try:
+        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        base_dir = os.path.join(project_root, "data", "checkpoints")
+        if not os.path.exists(base_dir):
+            return
+
+        # 获取当前日期
+        current_date = datetime.datetime.now()
+
+        # 遍历目录中的所有.db文件
+        for filename in os.listdir(base_dir):
+            if filename.startswith("checkpoints_") and filename.endswith(".db"):
+                try:
+                    print(f"检查旧检查点文件: {filename}")
+                    # 提取日期 (checkpoints_day_20251229.db -> 20251229)
+                    date_str = filename.replace("checkpoints_day_", "").replace(
+                        ".db", ""
+                    )
+                    file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
+
+                    # 计算天数差
+                    days_diff = (current_date - file_date).days
+
+                    # 删除超过 max_days 天的旧数据
+                    if days_diff > max_days:
+                        file_path = os.path.join(base_dir, filename)
+                        os.remove(file_path)
+                        print(f"🧹🧹 清理旧检查点文件: {filename} (超过 {max_days} 天)")
+
+                except (ValueError, IndexError):
+                    # 文件名不符合预期,跳过
+                    continue
+
+    except Exception as e:
+        print(f"⚠️ 清理旧检查点失败: {e}")
+
+
+# 创建agent
+def create_langchain_agent(
+    backend_url: str = "",
+    token: str = "",
+    username: str = "default",
+    thread_id: str = "default",
+):
+    llm = ChatOpenAI(
+        model=settings.LLM_MODEL,
+        temperature=settings.LLM_TEMPERATURE,
+        api_key=settings.DEEPSEEK_API_KEY,
+        base_url=settings.DEEPSEEK_BASE_URL,
+        max_tokens=settings.LLM_MAX_TOKENS,
+    )
+
+    tools = get_all_tools()
+    # 添加调试信息
+    print(f"🔧🔧🔧🔧 Agent 创建调试信息:")
+    print(f"  - 用户: {username}")
+    print(f"  - Thread ID: {thread_id}")
+    print(f"  - 后端地址: {backend_url}")
+    print(f"  - Token: {'已提供' if token else '未提供'}")
+    print(f"  - 工具数量: {len(tools)}")
+
+    for i, tool in enumerate(tools):
+        print(f"  - 工具 {i+1}: {tool.name}")
+
+    # 获取动态的system_prompt
+    system_prompt = create_system_prompt(backend_url, token, username)
+
+    def simple_turn_based_trim(
+        messages: Sequence[BaseMessage],
+        keep_turns: int = 3,
+        system_message: BaseMessage = None,
+    ) -> List[BaseMessage]:
+        """
+        修正版:按完整对话轮次修剪消息
+        每轮对话从Human开始,到下一个Human之前结束
+        """
+        if not messages:
+            return []
+        # 分离系统消息(始终保留)
+        system_messages = []
+        other_messages = []
+        for msg in messages:
+            if (
+                isinstance(msg, SystemMessage)
+                or getattr(msg, "type", None) == "system"
+                or getattr(msg, "role", None) == "system"
+                or msg.__class__.__name__ == "SystemMessage"
+            ):
+                system_messages.append(msg)
+            else:
+                other_messages.append(msg)
+
+        if len(other_messages) <= 1:
+            return system_messages + other_messages
+
+        # 找出所有Human消息的位置
+        human_indices = []
+        for i, msg in enumerate(other_messages):
+            if (
+                isinstance(msg, HumanMessage)
+                or getattr(msg, "type", None) == "human"
+                or getattr(msg, "role", None) == "user"
+            ):
+                human_indices.append(i)
+
+        # 如果Human消息不足keep_turns轮,返回所有
+        if not human_indices or len(human_indices) <= keep_turns:
+            return system_messages + other_messages
+
+        # 计算起始索引
+        start_idx = human_indices[-keep_turns]
+
+        # 获取要保留的消息
+        preserved_messages = other_messages[start_idx:]
+
+        # 4. 返回从该索引开始的所有消息
+        result = system_messages + preserved_messages
+        # print(f"修剪后消息数: {len(result)}")
+
+        return result
+
+    @before_model
+    def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        """Keep only the last few messages to fit context window."""
+        messages = state["messages"]
+
+        if len(messages) <= 3:
+            return None  # No changes needed
+
+        # 保留最后4轮对话
+        trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
+
+        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
+
+    # 使用SQLiteSaver(按天分割)
+    checkpointer = get_sqlite_checkpointer()
+    # print(f"打印检查点保存器: {checkpointer}")
+
+    # 可选:清理旧检查点(可配置为定期执行)
+    if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
+        cleanup_old_checkpoints(max_days=7)  # 保留最近7天数据
+
+    agent = create_agent(
+        llm,
+        tools,
+        checkpointer=checkpointer,
+        system_prompt=system_prompt,
+        middleware=[trim_messages],
+    )
+
+    return agent

+ 133 - 0
core/agent_manager.py

@@ -0,0 +1,133 @@
+import asyncio
+import time
+import hashlib
+from typing import Dict, Optional
+from core.agent import create_langchain_agent
+from utils.logger import chat_logger
+
+
+class AgentManager:
+
+    def __init__(self):
+        self._local_agent_cache = {}  # 仅缓存agent配置,不缓存实例
+        self._is_shutdown = False
+
+    async def initialize(self):
+        """异步初始化管理器"""
+        self._is_shutdown = False
+        chat_logger.info("🔧🔧🔧🔧 Agent管理器初始化完成")
+
+    async def shutdown(self):
+        """异步关闭管理器"""
+        self._is_shutdown = True
+        self._local_agent_cache.clear()
+        chat_logger.info("🧹🧹🧹🧹 Agent管理器已关闭")
+
+    def _get_agent_config_key(
+        self, thread_id: str, username: str, backend_url: str, token: str
+    ) -> str:
+        """生成agent配置的缓存key"""
+        key_data = f"{thread_id}:{username}:{backend_url}:{token}"
+        return hashlib.md5(key_data.encode()).hexdigest()
+
+    def _get_user_identifier(self, username: str, token: str) -> str:
+        """生成用户标识符"""
+        if not username or username == "default":
+            username_part = "anonymous"
+        else:
+            username_part = username
+
+        if token and len(token) >= 8:
+            token_part = token[:8]
+        else:
+            token_part = "notoken"
+
+        return f"{username_part}_{token_part}"
+
+    async def get_agent_instance(
+        self, thread_id: str, username: str, backend_url: str, token: str
+    ):
+        if self._is_shutdown:
+            raise RuntimeError("Agent管理器已关闭")
+
+        clean_username = username or "anonymous"
+        clean_backend = backend_url or ""
+        clean_token = token or ""
+
+        user_id = self._get_user_identifier(clean_username, clean_token)
+        config_key = self._get_agent_config_key(
+            thread_id, clean_username, clean_backend, clean_token
+        )
+
+        # 检查本地配置缓存
+        current_time = time.time()
+        if config_key in self._local_agent_cache:
+            agent_instance, timestamp = self._local_agent_cache[config_key]
+            if current_time - timestamp <= 300:  # 5分钟本地缓存
+                chat_logger.info(f"使用本地缓存的agent配置: 用户={user_id}")
+                return agent_instance
+
+        chat_logger.info(f"创建新的agent实例: 用户={user_id}")
+        agent_instance = await self._create_agent_async(
+            backend_url=clean_backend,
+            token=clean_token,
+            username=clean_username,
+            thread_id=thread_id,
+        )
+
+        # 缓存agent配置到本地
+        self._local_agent_cache[config_key] = (agent_instance, current_time)
+        chat_logger.info(f"Agent配置已缓存: 用户={user_id}")
+
+        return agent_instance
+
+    async def _create_agent_async(
+        self, backend_url: str, token: str, username: str, thread_id: str
+    ):
+        """创建agent实例"""
+
+        def sync_create_agent():
+            return create_langchain_agent(
+                backend_url=backend_url,
+                token=token,
+                username=username,
+                thread_id=thread_id,
+            )
+
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, sync_create_agent)
+
+    async def clear_user_agent(
+        self, thread_id: str, username: str, backend_url: str, token: str
+    ):
+        """清除特定用户的agent配置缓存"""
+        config_key = self._get_agent_config_key(thread_id, username, backend_url, token)
+
+        # 清除本地缓存
+        if config_key in self._local_agent_cache:
+            del self._local_agent_cache[config_key]
+
+        user_id = self._get_user_identifier(username, token)
+        chat_logger.info(f"已清除用户Agent配置缓存: {user_id}")
+
+    async def get_cache_status(self):
+        """获取缓存状态"""
+        if self._is_shutdown:
+            return {"status": "shutdown", "cache_size": 0}
+
+        return {
+            "local_config_cache_size": len(self._local_agent_cache),
+            "status": "active",
+            "message": "",
+        }
+
+    async def clear_cache(self):
+        """清空本地配置缓存"""
+        local_count = len(self._local_agent_cache)
+        self._local_agent_cache.clear()
+        chat_logger.info(f"清空本地配置缓存: {local_count}个配置")
+        return local_count
+
+
+# 全局实例
+agent_manager = AgentManager()

+ 175 - 0
core/chat_service.py

@@ -0,0 +1,175 @@
+# core/chat_service.py - 修复版本
+import asyncio
+from typing import Dict, Any, List
+from langchain_core.messages import HumanMessage
+from utils.logger import chat_logger, log_chat_entry
+from core.agent_manager import agent_manager
+
+
+class ChatService:
+    """聊天服务 - 支持真正并发的版本"""
+
+    def __init__(self):
+        self.agent_manager = agent_manager
+        # 创建专用的线程池用于执行同步的Langchain操作
+        self._thread_pool = None
+
+    def _get_thread_pool(self):
+        """获取或创建线程池"""
+        if self._thread_pool is None:
+            import concurrent.futures
+
+            # 创建足够大的线程池支持并发
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor(
+                max_workers=20,  # 根据服务器配置调整
+                thread_name_prefix="langchain_worker",
+            )
+        return self._thread_pool
+
+    async def process_chat_request(
+        self, request_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """异步处理聊天请求 - 真正并发版本"""
+        try:
+            # 提取请求数据
+            message = request_data["message"]
+            thread_id = request_data["thread_id"]
+            username = request_data["username"]
+            backend_url = request_data["backend_url"]
+            token = request_data["token"]
+
+            # 生成用户标识符
+            user_id = self.agent_manager._get_user_identifier(username, token)
+
+            chat_logger.info(
+                f"收到请求 - 用户={user_id} , 线程ID={thread_id}, 消息={message[:100]}"
+            )
+
+            # 异步获取agent实例
+            agent = await self.agent_manager.get_agent_instance(
+                thread_id=thread_id,
+                username=username,
+                backend_url=backend_url,
+                token=token,
+            )
+
+            # ✅ 修复:在线程池中执行同步的Langchain操作
+            result = await self._run_agent_in_threadpool(
+                agent, message, thread_id, user_id
+            )
+
+            chat_logger.info(f"Agent处理完成 - 用户={user_id}")
+
+            if not isinstance(result, dict) or "messages" not in result:
+                raise ValueError(f"Agent返回格式异常: {type(result)}")
+
+            # 处理结果
+            return self._process_agent_result(result, user_id, request_data)
+
+        except Exception as e:
+            chat_logger.error(f"聊天处理失败: {str(e)}")
+            raise
+
+    async def _run_agent_in_threadpool(
+        self, agent, message: str, thread_id: str, user_id: str
+    ):
+        """在线程池中执行Langchain Agent"""
+        loop = asyncio.get_event_loop()
+        thread_pool = self._get_thread_pool()
+
+        # 准备输入
+        inputs = {"messages": [HumanMessage(content=message)]}
+        config = {"configurable": {"thread_id": thread_id}}
+
+        chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
+
+        try:
+            # 在线程池中执行同步操作
+            result = await loop.run_in_executor(
+                thread_pool, lambda: agent.invoke(inputs, config)
+            )
+            return result
+        except Exception as e:
+            chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
+            raise
+
+    def _process_agent_result(
+        self, result: Dict[str, Any], user_id: str, request_data: Dict
+    ) -> Dict[str, Any]:
+        """处理Agent返回结果"""
+        all_messages = result["messages"]
+        processed_messages = []
+        all_ai_messages = []
+        all_tool_calls = []
+        final_answer = ""
+
+        for i, msg in enumerate(all_messages):
+            msg_data = {
+                "index": i,
+                "type": getattr(msg, "type", "unknown"),
+                "content": "",
+            }
+
+            # 获取内容
+            if hasattr(msg, "content"):
+                content = msg.content
+                if isinstance(content, str):
+                    msg_data["content"] = content
+                else:
+                    msg_data["content"] = str(content)
+
+            # 获取工具调用
+            if hasattr(msg, "tool_calls") and msg.tool_calls:
+                msg_data["tool_calls"] = msg.tool_calls
+                all_tool_calls.extend(msg.tool_calls)
+
+                for tool_call in msg.tool_calls:
+                    tool_name = tool_call.get("name", "unknown")
+                    tool_args = tool_call.get("args", {})
+                    chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
+
+            if hasattr(msg, "tool_call_id"):
+                msg_data["tool_call_id"] = msg.tool_call_id
+
+            if hasattr(msg, "name"):
+                msg_data["name"] = msg.name
+
+            processed_messages.append(msg_data)
+
+            # 收集AI消息
+            if msg_data["type"] == "ai":
+                all_ai_messages.append(msg_data)
+                final_answer = msg_data["content"]
+
+        # 构建响应
+        response = {
+            "final_answer": final_answer,
+            "all_ai_messages": all_ai_messages,
+            "all_messages": processed_messages,
+            "tool_calls": all_tool_calls,
+            "thread_id": request_data["thread_id"],
+            "user_identifier": user_id,
+            "backend_config": {
+                "backend_url": request_data["backend_url"] or "未配置",
+                "username": request_data["username"],
+                "has_token": bool(request_data["token"]),
+            },
+            "success": True,
+        }
+
+        # 记录日志
+        log_chat_entry(user_id, request_data["message"], response)
+        chat_logger.info(f"请求处理完成 - 用户={user_id}")
+
+        return response
+
+    async def shutdown(self):
+        """关闭线程池"""
+        if self._thread_pool:
+            self._thread_pool.shutdown(wait=False)
+            self._thread_pool = None
+            chat_logger.info("聊天服务线程池已关闭")
+
+
+# 全局实例
+chat_service = ChatService()

+ 165 - 0
core/document_processor/document_service.py

@@ -0,0 +1,165 @@
+from typing import Dict, Any, List, Optional
+from core.ocr_service import PaddleOCRService
+from .llm_parser import LLMParser
+from .templates.template_registry import TemplateRegistry
+import os
+from utils.logger import chat_logger
+
+
+class DocumentProcessingService:
+    """单据处理服务(简化版)"""
+
+    def __init__(self, ocr_service: PaddleOCRService = None):
+        # 依赖注入OCR服务
+        self.ocr_service = ocr_service or self._create_default_ocr_service()
+        self.llm_parser = LLMParser()
+
+    def _create_default_ocr_service(self) -> PaddleOCRService:
+        """创建默认OCR服务"""
+        return PaddleOCRService(
+            api_url=os.getenv("PADDLE_OCR_API_URL", ""),
+            token=os.getenv("PADDLE_OCR_TOKEN", ""),
+        )
+
+    async def ocr_create_bill(
+        self, image_bytes: bytes, document_type: str = None, ocr_options: Dict = None
+    ) -> Dict[str, Any]:
+        """
+        通过扫描图片辅助建立单据
+
+        Args:
+            image_bytes: 图片字节
+            document_type: 单据类型,如'invoice'
+            ocr_options: OCR选项
+
+        Returns:
+            处理结果
+        """
+        # 1. OCR识别
+        ocr_result = await self.ocr_service.recognize_image_async(
+            image_bytes, **(ocr_options or {})
+        )
+
+        # 2. 提取文本
+        ocr_text = self.ocr_service.extract_text_from_result(ocr_result)
+
+        if not ocr_text or not ocr_text.strip():
+            chat_logger.warning("OCR未识别到有效文本")
+            ocr_text = ""
+
+        # 3. 如果没有指定类型,尝试自动检测
+        if not document_type and ocr_text:
+            document_type = await self._detect_document_type(ocr_text)
+
+        # 4. 如果有单据类型,用LLM解析
+        parsed_data = None
+        if document_type and ocr_text:
+            template = TemplateRegistry.get_template(document_type)
+            parsed_data = await self.llm_parser.parse_to_json(ocr_text, template)
+
+        # 5. 返回结果
+        result = {
+            "success": parsed_data is not None,
+            "ocr_raw": ocr_result,  # 原始API返回
+            "ocr_text": ocr_text,  # 格式化文本
+        }
+
+        if parsed_data:
+            result.update(
+                {
+                    "document_type": document_type,
+                    "parsed_data": parsed_data,
+                    "has_parsed_data": True,
+                }
+            )
+
+        return result
+
+    async def pure_ocr(
+        self, image_bytes: bytes, ocr_options: Dict = None
+    ) -> Dict[str, Any]:
+        """
+        扫描图片并返回OCR识别结果
+
+        Args:
+            image_bytes: 图片字节
+            ocr_options: OCR选项
+
+        Returns:
+            处理结果
+        """
+        # 1. OCR识别
+        ocr_result = await self.ocr_service.recognize_image_async(
+            image_bytes, **(ocr_options or {})
+        )
+
+        # 2. 提取文本
+        ocr_text = self.ocr_service.extract_text_from_result(ocr_result)
+
+        if not ocr_text or not ocr_text.strip():
+            chat_logger.warning("OCR未识别到有效文本")
+            ocr_text = ""
+
+        result = {
+            "ocr_text": ocr_text,  # 格式化文本
+        }
+
+        return result
+
+    async def message_create_bill(
+        self, message: str, document_type: str = None
+    ) -> Dict[str, Any]:
+        """
+        通过文本消息辅助建立单据
+
+        Args:
+            message: 文本消息
+            document_type: 单据类型,如'invoice'
+
+        Returns:
+            处理结果
+        """
+        print("message:", message)
+        print("document_type:", document_type)
+        template = TemplateRegistry.get_template(document_type)
+        parsed_data = await self.llm_parser.parse_to_json(message, template)
+        print("parsed_data:", parsed_data)
+
+        # 5. 返回结果
+        result = {"success": parsed_data is not None}
+
+        if parsed_data:
+            result.update(
+                {
+                    "document_type": document_type,
+                    "parsed_data": parsed_data,
+                    "has_parsed_data": True,
+                }
+            )
+        print("result:", result)
+        return result
+
+    async def _detect_document_type(self, ocr_text: str) -> Optional[str]:
+        """简单关键词检测单据类型"""
+        text_lower = ocr_text.lower()
+
+        if any(keyword in text_lower for keyword in ["发票", "增值税", "专用发票"]):
+            return "invoice"
+        elif any(keyword in text_lower for keyword in ["收据", "收款", "收条"]):
+            return "receipt"
+        elif any(keyword in text_lower for keyword in ["订单", "订货单", "采购单"]):
+            return "order"
+
+        return None
+
+    async def get_available_templates(self) -> List[Dict]:
+        """获取可用模板列表"""
+        templates = TemplateRegistry.list_templates()
+        return [
+            {
+                "name": name,
+                "description": desc,
+                "fields": TemplateRegistry.get_template(name).output_schema(),
+            }
+            for name, desc in templates.items()
+        ]

+ 54 - 0
core/document_processor/llm_parser.py

@@ -0,0 +1,54 @@
+from langchain_openai import ChatOpenAI
+from langchain.messages import HumanMessage, SystemMessage
+import json
+import re
+from typing import Dict, Any
+from utils.logger import chat_logger
+import os
+from config.settings import settings
+
+
+class LLMParser:
+    """LLM解析器"""
+
+    def __init__(self, llm_config: Dict = None):
+        self.llm = ChatOpenAI(
+            model=settings.LLM_MODEL,
+            temperature=settings.LLM_TEMPERATURE,
+            api_key=settings.DEEPSEEK_API_KEY,
+            base_url=settings.DEEPSEEK_BASE_URL,
+            max_tokens=settings.LLM_MAX_TOKENS,
+        )
+
+    async def parse_to_json(self, ocr_text: str, template) -> Dict[str, Any]:
+        """使用模板解析OCR文本为JSON"""
+        print("template.system_prompt:", template.system_prompt)
+        messages = [
+            SystemMessage(content=template.system_prompt),
+            HumanMessage(content=f"请从以下文本中提取信息:\n\n{ocr_text}"),
+        ]
+
+        try:
+            response = await self.llm.ainvoke(messages)
+            content = response.content  # response.generations[0][0].text
+
+            # 提取JSON(处理可能的Markdown格式)
+            json_match = re.search(r"```json\n(.*?)\n```", content, re.DOTALL)
+            if json_match:
+                content = json_match.group(1)
+
+            result = json.loads(content)
+
+            # 验证结果
+            if template.validate_result(result):
+                result = template.post_process(result)
+                return result
+            else:
+                raise ValueError("解析结果验证失败")
+
+        except json.JSONDecodeError as e:
+            chat_logger.error(f"JSON解析失败: {e}, 原始内容: {content}")
+            raise
+        except Exception as e:
+            chat_logger.error(f"LLM解析失败: {e}")
+            raise

+ 153 - 0
core/document_processor/templates/base_template.py

@@ -0,0 +1,153 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any
+from pydantic import BaseModel
+import json
+from datetime import datetime
+from config.template_config_manager import TemplateConfigManager  # 新增导入
+
+
+class DocumentTemplate(ABC):
+    """单据模板基类(支持配置扩展)"""
+
+    def __init__(self, config_manager: TemplateConfigManager = None):
+        # 新增:配置管理器
+        self.config_manager = config_manager or TemplateConfigManager()
+        self._template_config = self.config_manager.get_template_config(
+            self.template_name
+        )
+
+    @property
+    @abstractmethod
+    def template_name(self) -> str:
+        """模板名称标识"""
+        pass
+
+    @property
+    @abstractmethod
+    def description(self) -> str:
+        """模板描述"""
+        pass
+
+    def get_hardcoded_guidance(self) -> Dict[str, Any]:
+        """硬编码的字段指导信息(子类可重写)"""
+        return {"field_guidance": {}, "additional_rules": ""}  # 项目初期为空
+
+    @property
+    def system_prompt(self) -> str:
+        """系统提示词(支持配置扩展)"""
+        # 获取硬编码指导
+        hardcoded_guidance = self.get_hardcoded_guidance()
+        hardcoded_field_guidance = hardcoded_guidance.get("field_guidance", {})
+        hardcoded_additional_rules = hardcoded_guidance.get("additional_rules", "")
+
+        # 获取配置指导(已自动过滤空白值)
+        configured_field_guidance = self._template_config.get("field_guidance", {})
+        configured_additional_rules = self._template_config.get("additional_rules", "")
+
+        # 合并字段指导信息
+        merged_field_guidance = self._merge_field_guidance(
+            hardcoded_field_guidance, configured_field_guidance
+        )
+
+        base_prompt = f"""你是一个专业的单据信息提取助手。现在时间是{datetime.now().isoformat()}
+请从OCR识别结果或用户的输入中提取信息,并严格按照以下JSON格式返回:
+
+{json.dumps(self.output_schema(), indent=2, ensure_ascii=False)}
+
+提取规则:
+{self.extraction_rules()}"""
+
+        # 添加合并后的指导信息(仅在存在内容时添加)
+        extended_prompt = self._build_extended_guidance(
+            merged_field_guidance,
+            hardcoded_additional_rules,
+            configured_additional_rules,
+        )
+
+        if extended_prompt:
+            base_prompt += f"\n\n额外指导信息:\n{extended_prompt}"
+
+        base_prompt += """
+
+请确保:
+1. 只提取明确存在的字段
+2. 日期格式统一为YYYY-MM-DD
+3. 数字类型保持原样"""
+
+        return base_prompt
+
+    def _merge_field_guidance(
+        self, hardcoded: Dict[str, list], configured: Dict[str, list]
+    ) -> Dict[str, list]:
+        """合并硬编码和配置的字段指导信息(自动去重)"""
+        merged = {}
+        all_fields = set(hardcoded.keys()) | set(configured.keys())
+
+        for field in all_fields:
+            hardcoded_hints = hardcoded.get(field, [])
+            configured_hints = configured.get(field, [])
+
+            # 合并并去重,保持顺序
+            combined = []
+            seen = set()
+            for hint in hardcoded_hints + configured_hints:
+                if hint and hint not in seen:  # 过滤空值并去重
+                    combined.append(hint)
+                    seen.add(hint)
+
+            if combined:  # 只添加有内容的字段
+                merged[field] = combined
+
+        return merged
+
+    def _build_extended_guidance(
+        self,
+        field_guidance: Dict[str, list],
+        hardcoded_rules: str,
+        configured_rules: str,
+    ) -> str:
+        """构建扩展指导信息"""
+        guidance_parts = []
+
+        # 字段指导信息(仅在存在内容时添加)
+        if field_guidance:
+            guidance_parts.append("字段识别指导:")
+            for field_name, hints in field_guidance.items():
+                if hints:
+                    combined_hints = "; ".join(hints)
+                    guidance_parts.append(f"- {field_name}: {combined_hints}")
+
+        # 合并额外规则(过滤空值)
+        additional_rules = []
+        if hardcoded_rules and hardcoded_rules.strip():
+            additional_rules.append(hardcoded_rules)
+        if configured_rules and configured_rules.strip():
+            additional_rules.append(configured_rules)
+
+        if additional_rules:
+            guidance_parts.append(f"特殊规则: {'; '.join(additional_rules)}")
+
+        return "\n".join(guidance_parts) if guidance_parts else ""
+
+    @abstractmethod
+    def output_schema(self) -> Dict[str, Any]:
+        """返回JSON输出结构定义"""
+        pass
+
+    @abstractmethod
+    def extraction_rules(self) -> str:
+        """字段提取规则说明"""
+        pass
+
+    def validate_result(self, result: Dict) -> bool:
+        """验证解析结果"""
+        return True
+
+    def post_process(self, result: Dict) -> Dict:
+        """后处理钩子"""
+        result["_metadata"] = {
+            "template": self.template_name,
+            "processed_at": datetime.now().isoformat(),
+            "template_description": self.description,
+        }
+        return result

+ 40 - 0
core/document_processor/templates/cusamt_template.py

@@ -0,0 +1,40 @@
+from typing import Any, Dict
+from core.document_processor.templates.base_template import DocumentTemplate
+
+
+class CusAmtTemplate(DocumentTemplate):
+    """客户收款单模板"""
+
+    @property
+    def template_name(self) -> str:
+        return "cusamt"
+
+    @property
+    def description(self) -> str:
+        return "客户收款单识别模板"
+
+    def get_hardcoded_guidance(self) -> Dict[str, Any]:
+        """硬编码的字段指导信息(项目初期为空,验收后转移内容到这里)"""
+        return {
+            "field_guidance": {
+                # 项目初期保持为空,验收后将template_config.json的内容转移到这里
+                # "cusname": ["可能的字段描述:付款人户名、付款方、转出方、付款人名称"],
+                # "cusamt": ["金额字段可能包含:金额、收款金额、实收金额"],
+            },
+            "additional_rules": "",  # 硬编码的额外规则
+        }
+
+    def output_schema(self) -> Dict[str, Any]:
+        return {
+            "cusname": "付款人户名,或付款方,或转出方",
+            "viewdate": "到账日期",
+            "cusamt": "金额",
+            "accname": "收款账户名称",
+            "acccode": "收款帐号",
+            "dscrp": "备注",
+            "kindstr": "收款类型(规范为以下几类:余款,订金,预收款,其它,缺省值:余款)",
+        }
+
+    def extraction_rules(self) -> str:
+        return """
+        """

+ 47 - 0
core/document_processor/templates/invoice_template.py

@@ -0,0 +1,47 @@
+from typing import Any, Dict
+from core.document_processor.templates.base_template import DocumentTemplate
+
+
+class InvoiceTemplate(DocumentTemplate):
+    """发票模板"""
+
+    @property
+    def template_name(self) -> str:
+        return "invoice"
+
+    @property
+    def description(self) -> str:
+        return "增值税发票识别模板"
+
+    def output_schema(self) -> Dict[str, Any]:
+        return {
+            "invoice_code": "发票代码",
+            "invoice_number": "发票号码",
+            "issue_date": "开票日期",
+            "seller_name": "销售方名称",
+            "seller_tax_id": "销售方税号",
+            "buyer_name": "购买方名称",
+            "buyer_tax_id": "购买方税号",
+            "amount_without_tax": "不含税金额",
+            "tax_amount": "税额",
+            "total_amount": "价税合计",
+            "items": [
+                {
+                    "name": "货物或服务名称",
+                    "specification": "规格型号",
+                    "unit": "单位",
+                    "quantity": "数量",
+                    "unit_price": "单价",
+                    "amount": "金额",
+                }
+            ],
+        }
+
+    def extraction_rules(self) -> str:
+        return """
+        """
+
+    def validate_result(self, result: Dict) -> bool:
+        # 验证必填字段
+        required_fields = ["invoice_number", "total_amount", "issue_date"]
+        return all(field in result for field in required_fields)

+ 51 - 0
core/document_processor/templates/saletask_template.py

@@ -0,0 +1,51 @@
+from typing import Any, Dict
+from core.document_processor.templates.base_template import DocumentTemplate
+
+
+class SaleTaskTemplate(DocumentTemplate):
+    """销售订单模板"""
+
+    @property
+    def template_name(self) -> str:
+        return "saletask"
+
+    @property
+    def description(self) -> str:
+        return "销售订单识别模板"
+
+    def get_hardcoded_guidance(self) -> Dict[str, Any]:
+        """硬编码的字段指导信息(项目初期为空,验收后转移内容到这里)"""
+        return {
+            "field_guidance": {
+                # 项目初期保持为空,验收后将template_config.json的内容转移到这里
+            },
+            "additional_rules": "",  # 硬编码的额外规则
+        }
+
+    def output_schema(self) -> Dict[str, Any]:
+        return {
+            "cusname": "客户名称",
+            "taskdate": "订货日期",
+            "requiredate": "交货日期",
+            "banktype": "结算方式",
+            "relcode": "相关号码",
+            "otheramt": "优惠金额",
+            "damt": "订金",
+            "cus_tele": "客户联系电话",
+            "rel_rep": "客户联系人",
+            "dscrp": "单据备注描述",
+            "freight": "货运部",
+            "items": [
+                {
+                    "mtrlname": "货物\产品名称",
+                    "unit": "计量单位(例如张,台,件等)",
+                    "saleqty": "数量",
+                    "enprice": "单价",
+                    "rebate": "折扣",
+                }
+            ],
+        }
+
+    def extraction_rules(self) -> str:
+        return """
+        """

+ 36 - 0
core/document_processor/templates/template_registry.py

@@ -0,0 +1,36 @@
+from typing import Dict, Type
+from core.document_processor.templates.base_template import DocumentTemplate
+from core.document_processor.templates.invoice_template import InvoiceTemplate
+from core.document_processor.templates.saletask_template import SaleTaskTemplate
+from core.document_processor.templates.cusamt_template import CusAmtTemplate
+
+
+class TemplateRegistry:
+    """模板注册管理器"""
+
+    _templates: Dict[str, Type[DocumentTemplate]] = {}
+
+    @classmethod
+    def register(cls, template_class: Type[DocumentTemplate]):
+        """注册模板"""
+        instance = template_class()
+        cls._templates[instance.template_name] = template_class
+        return template_class
+
+    @classmethod
+    def get_template(cls, template_name: str) -> DocumentTemplate:
+        """获取模板实例"""
+        if template_name not in cls._templates:
+            raise ValueError(f"模板不存在: {template_name}")
+        return cls._templates[template_name]()
+
+    @classmethod
+    def list_templates(cls) -> Dict[str, str]:
+        """列出所有可用模板"""
+        return {name: cls._templates[name]().description for name in cls._templates}
+
+
+# 注册内置模板
+TemplateRegistry.register(InvoiceTemplate)
+TemplateRegistry.register(SaleTaskTemplate)
+TemplateRegistry.register(CusAmtTemplate)

+ 29 - 0
core/lifespan_manager.py

@@ -0,0 +1,29 @@
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from core.agent_manager import agent_manager
+from utils.registration_manager import registration_manager
+from utils.logger import chat_logger
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """应用生命周期管理"""
+    try:
+        # 启动时检查注册状态,但不阻止启动
+        registration_status = await registration_manager.check_registration()
+        if not registration_status:
+            chat_logger.warning("⚠️ 服务启动:注册检查未通过,/chat接口将受限")
+        else:
+            chat_logger.info("✅ 服务启动:注册检查通过")
+
+        await agent_manager.initialize()
+        chat_logger.info("🚀🚀 AI助手服务启动")
+        yield
+    finally:
+        cleared_count = await agent_manager.shutdown()
+        chat_logger.info(f"🛑🛑🛑 AI助手服务停止,清理了 {cleared_count} 个Agent实例")
+
+
+def create_lifespan():
+    """创建生命周期管理器"""
+    return lifespan

+ 134 - 0
core/ocr_service.py

@@ -0,0 +1,134 @@
+import base64
+import requests
+from typing import Dict, Any, Optional
+from utils.logger import chat_logger
+import aiohttp
+import asyncio
+
+
+class PaddleOCRService:
+    """PaddleOCR API服务封装"""
+
+    def __init__(self, api_url: str, token: str):
+        self.api_url = api_url
+        self.token = token
+        self.headers = {
+            "Authorization": f"token {token}",
+            "Content-Type": "application/json",
+        }
+
+    async def recognize_image_async(
+        self,
+        image_bytes: bytes,
+        file_type: int = 1,  # 1: 图片, 0: PDF
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        use_textline_orientation: bool = False,
+    ) -> Dict[str, Any]:
+        """异步调用OCR API"""
+
+        # 同步调用(如果不需要异步,可以用requests)
+        return self.recognize_image_sync(
+            image_bytes=image_bytes,
+            file_type=file_type,
+            use_doc_orientation_classify=use_doc_orientation_classify,
+            use_doc_unwarping=use_doc_unwarping,
+            use_textline_orientation=use_textline_orientation,
+        )
+
+    def recognize_image_sync(
+        self,
+        image_bytes: bytes,
+        file_type: int = 1,
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        use_textline_orientation: bool = False,
+    ) -> Dict[str, Any]:
+        """同步调用OCR API"""
+
+        try:
+            # 转换为base64
+            file_data = base64.b64encode(image_bytes).decode("ascii")
+
+            payload = {
+                "file": file_data,
+                "fileType": file_type,
+                "useDocOrientationClassify": use_doc_orientation_classify,
+                "useDocUnwarping": use_doc_unwarping,
+                "useTextlineOrientation": use_textline_orientation,
+            }
+
+            response = requests.post(
+                self.api_url, json=payload, headers=self.headers, timeout=30
+            )
+
+            if response.status_code != 200:
+                chat_logger.error(
+                    f"OCR API调用失败: {response.status_code}, {response.text}"
+                )
+                raise Exception(f"OCR识别失败: {response.status_code}")
+
+            result = response.json()
+
+            if "result" not in result:
+                chat_logger.error(f"OCR返回格式错误: {result}")
+                raise Exception("OCR返回格式错误")
+
+            return result["result"]
+
+        except requests.exceptions.Timeout:
+            chat_logger.error("OCR API调用超时")
+            raise Exception("OCR识别超时")
+        except Exception as e:
+            chat_logger.error(f"OCR识别异常: {str(e)}")
+            raise
+
+    def extract_text_from_result(self, json_data: Dict) -> str:
+        """从OCR结果中提取文本(格式化)"""
+
+        try:
+            ocr_result = json_data["ocrResults"][0]["prunedResult"]
+            texts = ocr_result["rec_texts"]
+            scores = ocr_result["rec_scores"]
+            boxes = ocr_result["rec_boxes"]
+
+            ocr_text = "识别文本 | 识别框坐标"
+            ocr_text += "\r\n" + ("-" * 50)
+            for text, score, box in zip(texts, scores, boxes):
+                if score > 0.5:
+                    ocr_text += f"\r\n{text} | {box}"
+
+            # # 按置信度过滤并排序
+            # filtered_results = []
+            # for i, (text, score) in enumerate(zip(texts, scores)):
+            #     if score > 0.5:  # 置信度阈值
+            #         filtered_results.append(
+            #             {
+            #                 "index": i,
+            #                 "text": text,
+            #                 "score": float(score),
+            #                 "box": (
+            #                     ocr_data["rec_boxes"][i]
+            #                     if i < len(ocr_data["rec_boxes"])
+            #                     else None
+            #                 ),
+            #             }
+            #         )
+
+            # # 按位置排序(从上到下,从左到右)
+            # filtered_results.sort(
+            #     key=lambda x: (
+            #         x["box"][0][1] if x["box"] else 0,  # y坐标
+            #         x["box"][0][0] if x["box"] else 0,  # x坐标
+            #     )
+            # )
+
+            # # 拼接为文本
+            # ocr_text = "\n".join([item["text"] for item in filtered_results])
+
+            # chat_logger.info(f"OCR识别成功,识别到{len(filtered_results)}个文本块")
+            return ocr_text
+
+        except Exception as e:
+            chat_logger.error(f"提取OCR文本失败: {str(e)}")
+            raise

+ 278 - 0
cythonize.py

@@ -0,0 +1,278 @@
+# cythonize.py
+import os
+import sys
+import shutil
+from pathlib import Path
+import subprocess
+import time
+
+
+def print_banner():
+    """打印标题"""
+    print("=" * 60)
+    print("       LongjoeAgent - Cython 编译工具")
+    print("=" * 60)
+
+
+def check_environment():
+    """检查环境"""
+    print("\n🔍 检查环境...")
+
+    # 检查Python
+    if sys.version_info < (3, 7):
+        print(f"❌ 需要Python 3.7+,当前: {sys.version}")
+        return False
+
+    # 检查依赖
+    try:
+        import Cython
+
+        print(f"✅ Cython {Cython.__version__}")
+    except ImportError:
+        print("❌ 未安装Cython: pip install cython")
+        return False
+
+    try:
+        import setuptools
+
+        print(f"✅ setuptools {setuptools.__version__}")
+    except ImportError:
+        print("❌ 未安装setuptools: pip install setuptools wheel")
+        return False
+
+    return True
+
+
+def clean_build():
+    """清理构建目录"""
+    print("\n🧹 清理之前构建...")
+
+    dirs_to_remove = ["build", "dist", "cython_build", "LongjoeAgent.egg-info"]
+
+    for dir_name in dirs_to_remove:
+        dir_path = Path(dir_name)
+        if dir_path.exists():
+            try:
+                shutil.rmtree(dir_path)
+                print(f"  已删除: {dir_name}")
+            except:
+                pass
+
+    # 清理.c文件
+    for c_file in Path(".").rglob("*.c"):
+        try:
+            c_file.unlink()
+        except:
+            pass
+
+
+def compile_project():
+    """编译项目"""
+    print("\n⚡ 开始编译...")
+    start_time = time.time()
+
+    # 运行setup.py
+    result = subprocess.run(
+        [sys.executable, "setup.py"], capture_output=True, text=True, encoding="utf-8"
+    )
+
+    if result.returncode != 0:
+        print("❌ 编译失败!")
+        if result.stderr:
+            # 显示关键错误信息
+            lines = result.stderr.split("\n")
+            for line in lines:
+                if "error" in line.lower() or "Error" in line:
+                    print(f"  {line}")
+        return False
+
+    end_time = time.time()
+    print(f"✅ 编译完成 ({end_time - start_time:.1f}秒)")
+
+    # 显示编译结果
+    show_compile_results()
+
+    return True
+
+
+def show_compile_results():
+    """显示编译结果"""
+    compiled_dir = Path("build") / "compiled"
+
+    if not compiled_dir.exists():
+        print("❌ 编译目录不存在")
+        return
+
+    # 统计编译文件
+    pyd_files = list(compiled_dir.rglob("*.pyd"))
+    so_files = list(compiled_dir.rglob("*.so"))
+
+    if pyd_files:
+        print(f"\n📦 生成 {len(pyd_files)} 个 .pyd 文件:")
+        for f in pyd_files[:5]:
+            print(f"  • {f.relative_to(compiled_dir)}")
+        if len(pyd_files) > 5:
+            print(f"  ... 和 {len(pyd_files) - 5} 个其他文件")
+    elif so_files:
+        print(f"\n📦 生成 {len(so_files)} 个 .so 文件:")
+        for f in so_files[:5]:
+            print(f"  • {f.relative_to(compiled_dir)}")
+        if len(so_files) > 5:
+            print(f"  ... 和 {len(so_files) - 5} 个其他文件")
+    else:
+        print("❌ 未找到编译文件")
+
+
+def create_deployment():
+    """创建部署包"""
+    print("\n📦 创建部署包...")
+
+    compiled_dir = Path("build") / "compiled"
+    deploy_dir = Path("build") / "deploy"
+
+    if not compiled_dir.exists():
+        print("❌ 编译目录不存在")
+        return
+
+    # 清理旧的部署包
+    if deploy_dir.exists():
+        shutil.rmtree(deploy_dir)
+
+    # 复制编译文件
+    deploy_dir.mkdir(parents=True, exist_ok=True)
+    for src_file in compiled_dir.rglob("*"):
+        if src_file.is_file():
+            rel_path = src_file.relative_to(compiled_dir)
+            dst_file = deploy_dir / rel_path
+            dst_file.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy2(src_file, dst_file)
+
+    # 复制必要文件
+    essential_files = [
+        "app.py",
+        ".env.example",
+        "requirements.txt",
+        ".registration.example",
+    ]
+    for file_name in essential_files:
+        src = Path(file_name)
+        if src.exists():
+            dst = deploy_dir / file_name
+            shutil.copy2(src, dst)
+            print(f"  复制: {file_name}")
+
+    # 复制 config 目录下的 JSON 配置文件
+    config_src = Path("config")
+    if config_src.exists():
+        config_dst = deploy_dir / "config"
+        config_dst.mkdir(exist_ok=True)
+
+        # 复制所有 JSON 文件
+        json_files = list(config_src.rglob("*.json"))
+        for json_file in json_files:
+            if json_file.is_file():
+                rel_path = json_file.relative_to(config_src)
+                dst_file = config_dst / rel_path
+                dst_file.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copy2(json_file, dst_file)
+                print(f"  复制配置文件: config/{rel_path}")
+
+        if json_files:
+            print(f"✅ 已复制 {len(json_files)} 个配置文件")
+
+    # 创建日志目录
+    (deploy_dir / "chat_logs").mkdir(exist_ok=True)
+
+    # 创建启动脚本
+    create_start_scripts(deploy_dir)
+
+    print(f"✅ 部署包: {deploy_dir}")
+
+
+def create_start_scripts(deploy_dir):
+    """创建启动脚本"""
+
+    # Windows
+    bat_content = """@echo off
+chcp 65001 >nul
+echo === LongjoeAgent ===
+echo.
+
+REM 检查Python
+python --version >nul 2>&1
+if errorlevel 1 (
+    echo 错误: 未找到Python
+    pause
+    exit /b 1
+)
+
+REM 启动
+echo 启动服务...
+python app.py
+
+pause
+"""
+
+    bat_file = deploy_dir / "start.bat"
+    bat_file.write_text(bat_content, encoding="utf-8")
+
+    # Linux/Mac
+    sh_content = """#!/bin/bash
+echo "=== LongjoeAgent ==="
+
+# 检查Python
+if ! command -v python3 &> /dev/null; then
+    echo "错误: 未找到Python3"
+    exit 1
+fi
+
+# 启动
+echo "启动服务..."
+python3 app.py
+"""
+
+    sh_file = deploy_dir / "start.sh"
+    sh_file.write_text(sh_content, encoding="utf-8")
+
+    if sys.platform != "win32":
+        os.chmod(sh_file, 0o755)
+
+    print("  创建启动脚本: start.bat, start.sh")
+
+
+def main():
+    """主函数"""
+    print_banner()
+
+    # 1. 检查环境
+    if not check_environment():
+        sys.exit(1)
+
+    # 2. 清理
+    clean_build()
+
+    # 3. 编译
+    if not compile_project():
+        sys.exit(1)
+
+    # 4. 创建部署包
+    create_deployment()
+
+    # 5. 完成
+    print("\n" + "=" * 60)
+    print("🎉 编译完成!")
+    print("=" * 60)
+
+    print("\n📁 文件结构:")
+    print("  • 源代码: 保持不变")
+    print("  • 编译文件: build/compiled/")
+    print("  • 部署包:   build/deploy/")
+
+    print("\n🚀 使用方法:")
+    print("  开发: python app.py")
+    print("  部署: 复制 build/deploy/ 到服务器")
+    print("  运行: 执行 start.sh 或 start.bat")
+
+
+if __name__ == "__main__":
+    main()

+ 30 - 0
middlewares/registration_middleware.py

@@ -0,0 +1,30 @@
+from fastapi import Request, HTTPException
+from fastapi.responses import JSONResponse
+from utils.registration_manager import registration_manager
+from utils.logger import chat_logger
+
+
+async def registration_check_middleware(request: Request, call_next):
+    """
+    注册检查中间件
+    对需要注册验证的接口进行检查
+    """
+    # 定义需要检查注册状态的接口列表
+    protected_paths = ["/chat", "/message_create_bill", "/ocr_create_bill", "/ocr"]
+    # 只拦截POST /chat请求
+    if request.url.path in protected_paths and request.method == "POST":
+        if not await registration_manager.check_registration():
+            chat_logger.warning(f"拒绝未注册访问: {request.client.host}")
+
+            if request.url.path == "/chat":
+                response_data = {
+                    "final_answer": "服务未注册或注册已过期,请联系管理员。"
+                }
+            else:
+                response_data = "服务未注册或注册已过期,请联系管理员。"
+
+            return JSONResponse(status_code=403, content=response_data)
+
+    # 其他请求直接放行
+    response = await call_next(request)
+    return response

+ 14 - 0
requirements.txt

@@ -0,0 +1,14 @@
+aiohttp==3.13.2
+cython==3.2.3
+fastapi==0.128.0
+langchain==1.2.7
+langchain_core==1.2.7
+langchain_openai==1.1.7
+langgraph==1.0.7
+pycryptodome==3.23.0
+pydantic==2.12.5
+python-dotenv==1.2.1
+regex==2025.11.3
+Requests==2.32.5
+setuptools==80.9.0
+uvicorn==0.40.0

+ 167 - 0
setup.py

@@ -0,0 +1,167 @@
+# setup.py
+from setuptools import setup, Extension
+from Cython.Build import cythonize
+from Cython.Distutils import build_ext
+import os
+import sys
+from pathlib import Path
+import re
+
+# 获取当前目录
+current_dir = Path(__file__).parent
+COMPILED_DIR = current_dir / "build" / "compiled"
+
+# 定义要排除的文件
+EXCLUDE_FILES = {
+    "app.py",  # 不编译主应用文件
+    "setup.py",  # 不编译本文件
+    "cythonize.py",  # 不编译辅助脚本
+    "check_deploy.py",  # 不编译部署检查
+}
+
+# 定义要排除的目录
+EXCLUDE_DIRS = {
+    "build",
+    "dist",
+    "__pycache__",
+    ".git",
+    ".vscode",
+    ".idea",
+    "venv",
+    "env",
+    "node_modules",
+    "chat_logs",
+    "logs",
+    "tests",
+    "test",
+    "cython_build",  # Cython临时目录
+}
+
+
+def should_exclude_dir(dir_name):
+    """判断是否应该排除目录"""
+    return dir_name in EXCLUDE_DIRS or dir_name.startswith(".")
+
+
+def find_py_files(base_dir):
+    """查找所有需要编译的Python文件"""
+    py_files = []
+
+    for root, dirs, files in os.walk(base_dir):
+        # 排除不需要的目录
+        dirs[:] = [d for d in dirs if not should_exclude_dir(d)]
+
+        for file in files:
+            if file.endswith(".py") and file not in EXCLUDE_FILES:
+                full_path = os.path.join(root, file)
+
+                # 排除build目录下的文件
+                if "build" in full_path.split(os.sep):
+                    continue
+
+                py_files.append(full_path)
+
+    return py_files
+
+
+def create_extension_modules():
+    """创建Cython扩展模块列表"""
+    extensions = []
+    py_files = find_py_files(current_dir)
+
+    for py_file in py_files:
+        # 转换为相对路径
+        rel_path = os.path.relpath(py_file, current_dir)
+
+        # 生成模块名(去掉.py)
+        module_name = rel_path[:-3].replace(os.sep, ".")
+
+        # 创建Extension
+        extension = Extension(
+            module_name,
+            [py_file],
+            extra_compile_args=[
+                "/O2" if sys.platform == "win32" else "-O3",
+                "/std:c++17" if sys.platform == "win32" else "-std=c++17",
+            ],
+            language="c",
+        )
+
+        extensions.append(extension)
+
+    return extensions
+
+
+def main():
+    """主编译函数"""
+    print(f"编译目标目录: {COMPILED_DIR}")
+
+    # 确保编译目录存在
+    COMPILED_DIR.mkdir(parents=True, exist_ok=True)
+
+    # 创建扩展模块
+    extensions = create_extension_modules()
+
+    if not extensions:
+        print("❌ 未找到需要编译的Python文件")
+        return
+
+    print(f"找到 {len(extensions)} 个Python文件需要编译:")
+    for i, ext in enumerate(extensions[:10], 1):  # 只显示前10个
+        print(f"  {i:2}. {ext.name}")
+    if len(extensions) > 10:
+        print(f"  ... 和 {len(extensions) - 10} 个其他文件")
+
+    # 配置编译选项
+    compiler_directives = {
+        "language_level": 3,
+        "boundscheck": False,  # 关闭边界检查
+        "wraparound": False,  # 关闭环绕检查
+        "initializedcheck": False,  # 关闭初始化检查
+        "nonecheck": False,  # 关闭None检查
+        "overflowcheck": False,  # 关闭溢出检查
+        "cdivision": True,  # 使用C除法
+        "infer_types": True,  # 类型推断
+        "optimize.use_switch": True,  # 优化switch语句
+    }
+
+    # 编译
+    cythonized = cythonize(
+        extensions,
+        compiler_directives=compiler_directives,
+        nthreads=0,  # 0表示自动使用所有核心
+        build_dir="./cython_build",  # 临时构建目录
+    )
+
+    # 修复:使用自定义BuildExt类
+    class CustomBuildExt(build_ext):
+        def initialize_options(self):
+            super().initialize_options()
+            # 设置输出目录
+            self.build_lib = str(COMPILED_DIR)
+
+        def get_ext_fullpath(self, ext_name):
+            """获取扩展的完整路径,确保输出到指定目录"""
+            # 首先调用父类方法
+            filename = self.get_ext_filename(ext_name)
+
+            # 确保输出到编译目录
+            return os.path.join(self.build_lib, filename)
+
+    # 修复:使用setuptools的setup,但通过cmdclass和options控制
+    setup(
+        name="LongjoeAgent",
+        ext_modules=cythonized,
+        cmdclass={"build_ext": CustomBuildExt},
+        options={
+            "build_ext": {
+                "build_lib": str(COMPILED_DIR),
+                "inplace": False,  # 不在原地构建
+            }
+        },
+        script_args=["build_ext"],  # 只构建扩展
+    )
+
+
+if __name__ == "__main__":
+    main()

+ 237 - 0
tools/base_tool.py

@@ -0,0 +1,237 @@
+import re
+import html
+import requests
+import json
+from typing import List, Dict, Any, Optional, Callable
+from pathlib import Path
+
+
+def html_to_text(html_content: str) -> str:
+    """HTML转文本"""
+    if not html_content:
+        return ""
+    clean = re.compile(r"<[^>]+>")
+    text = clean.sub("", html_content)
+    text = html.unescape(text)
+    return re.sub(r"\s+", " ", text).strip()
+
+
+def get_unique_match_count(search_text: str, filter_words: List[str]) -> int:
+    """获取唯一匹配计数"""
+    sorted_keywords = sorted(filter_words, key=len, reverse=True)
+    match_count = 0
+    remaining_text = search_text.lower()
+
+    for keyword in sorted_keywords:
+        kw_lower = keyword.lower()
+        if kw_lower in remaining_text:
+            match_count += 1
+            remaining_text = remaining_text.replace(kw_lower, "", 1)
+
+    return match_count
+
+
+def call_csharp_api(
+    backend_url: str, token: str, uoName: str, functionName: str, SParms: dict
+) -> str:
+    """调用C# API的通用方法"""
+    print(f"🔧 API调用调试信息:")
+    print(f"  - 后端地址: {backend_url}")
+    print(f"  - Token: {'已配置' if token else '未配置'}")
+    print(f"  - 功能: {functionName}")
+    print(f"  - 参数: {SParms}")
+
+    if not backend_url or not token:
+        error_msg = f"错误:未配置后端地址或认证令牌。后端: {backend_url or '未配置'}, Token: {'已配置' if token else '未配置'}"
+        print(f"❌ {error_msg}")
+        return error_msg
+
+    headers = {
+        "Accept": "application/json, text/plain, */*",
+        "Content-Type": "application/json",
+        "X-TOKEN": token,
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
+    }
+
+    payload = {
+        "token": token,
+        "CList": [
+            {
+                "uoName": uoName,
+                "functionName": functionName,
+                "SParms": SParms,
+                "ifcommit": True,
+                "returnStrList": [],
+            }
+        ],
+        "language": "zh-cn",
+    }
+
+    try:
+        print(f"🌐 发送API请求到: {backend_url}")
+        response = requests.post(backend_url, headers=headers, json=payload, timeout=30)
+        print(f"📡 响应状态码: {response.status_code}")
+
+        if response.status_code == 200:
+            data = response.json()
+            return process_api_response(data)
+        else:
+            error_msg = f"API请求失败,状态码: {response.status_code}"
+            print(f"❌ {error_msg}")
+            return error_msg
+    except Exception as e:
+        error_msg = f"API调用异常: {str(e)}"
+        print(f"❌ {error_msg}")
+        return error_msg
+
+
+def process_api_response(data: Dict[str, Any]) -> str:
+    """处理API响应"""
+    try:
+        inner_json_str = data.get("reJob", {}).get("0", "{}")
+        inner_data = json.loads(inner_json_str)
+
+        if "err_msg" in inner_data:
+            return f"API返回错误: {inner_data['err_msg']}"
+
+        if "data" in inner_data:
+            data_list = inner_data["data"]
+            if not data_list:
+                return "NO_DATA"
+
+            if isinstance(data_list[0], dict):
+                headers = list(data_list[0].keys())
+                result = [",".join(headers)]
+
+                for row in data_list:
+                    result.append(",".join([str(row.get(h, "")) for h in headers]))
+
+                return "\n".join(result)
+
+        return json.dumps(data, ensure_ascii=False)
+    except Exception as e:
+        return f"响应处理错误: {str(e)}"
+
+
+# 工具配置管理函数
+def load_tool_config(
+    config_path: Path, get_default_config: Optional[Callable] = None
+) -> Dict[str, Any]:
+    """
+    加载工具配置的通用函数
+
+    Args:
+        config_path: 配置文件路径
+        get_default_config: 获取默认配置的回调函数,如果不提供则返回空字典
+    """
+    if not config_path.exists():
+        print(f"警告: 配置文件不存在: {config_path}")
+        if get_default_config:
+            return get_default_config()
+        return {}
+
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except json.JSONDecodeError as e:
+        print(f"错误: 配置文件格式不正确: {e}")
+        if get_default_config:
+            return get_default_config()
+        return {}
+    except Exception as e:
+        print(f"错误: 读取配置文件失败: {e}")
+        if get_default_config:
+            return get_default_config()
+        return {}
+
+
+def assemble_tool_description(tool_config: Dict[str, Any]) -> str:
+    """组装工具描述,将所有键值组合成一个完整的字符串"""
+    if not tool_config:
+        return ""
+
+    description_parts = []
+
+    # 基础描述
+    if "基础描述" in tool_config:
+        description_parts.append(tool_config["基础描述"])
+
+    # 功能说明
+    if "功能说明" in tool_config:
+        description_parts.append(f"\n功能: {tool_config['功能说明']}")
+
+    # 入参说明
+    if "入参说明" in tool_config:
+        if isinstance(tool_config["入参说明"], dict):
+            description_parts.append("\n参数:")
+            for param, desc in tool_config["入参说明"].items():
+                description_parts.append(f"  {param}: {desc}")
+        else:
+            description_parts.append(f"\n参数说明: {tool_config['入参说明']}")
+
+    # 返回值说明
+    if "返回值说明" in tool_config:
+        if isinstance(tool_config["返回值说明"], dict):
+            description_parts.append("\n返回:")
+            for key, value in tool_config["返回值说明"].items():
+                if isinstance(value, list):
+                    description_parts.append(f"  {key}:")
+                    for item in value:
+                        description_parts.append(f"    - {item}")
+                else:
+                    description_parts.append(f"  {key}: {value}")
+        else:
+            description_parts.append(f"\n返回结果: {tool_config['返回值说明']}")
+
+    # 输出格式要求
+    if "输出格式要求" in tool_config:
+        if isinstance(tool_config["输出格式要求"], list):
+            description_parts.append("\n输出要求:")
+            for requirement in tool_config["输出格式要求"]:
+                description_parts.append(f"  - {requirement}")
+        else:
+            description_parts.append(f"\n注意: {tool_config['输出格式要求']}")
+
+    # 使用示例
+    if "使用示例" in tool_config:
+        description_parts.append(f"\n示例: {tool_config['使用示例']}")
+
+    return "\n".join(description_parts)
+
+
+def get_tool_prompt(
+    tool_name: str, default_config_func: Optional[Callable] = None
+) -> str:
+    """
+    获取工具的完整提示词
+
+    Args:
+        tool_name: 工具名称
+        default_config_func: 获取默认配置的函数
+    """
+    # 计算配置文件路径
+    current_file = Path(__file__)
+    config_path = current_file.parent.parent / "config" / "tool_config.json"
+
+    # 加载配置
+    config = load_tool_config(config_path, default_config_func)
+
+    # 获取工具配置
+    tool_config = config.get(tool_name, {})
+
+    # 如果配置为空且提供了默认配置函数,使用默认配置
+    if not tool_config and default_config_func:
+        default_config = default_config_func()
+        if isinstance(default_config, dict) and tool_name in default_config:
+            tool_config = default_config[tool_name]
+        elif isinstance(default_config, dict) and not default_config:
+            # 如果返回的是整个配置字典
+            tool_config = default_config
+        else:
+            tool_config = {}
+
+    # 组装描述
+    if tool_config:
+        return assemble_tool_description(tool_config)
+    else:
+        return f"执行 {tool_name} 功能"

+ 103 - 0
tools/knowledge_tools.py

@@ -0,0 +1,103 @@
+from langchain.tools import tool
+from typing import List
+import requests
+import json
+import os
+from .base_tool import html_to_text, get_unique_match_count
+
+
+@tool
+def get_knowledge_list(filter_words: List[str], match_limit: int = 3) -> str:
+    """获取知识库列表,返回知识库DocID和标题DocName以及关键字keyword的列表,根据关键字及标题,按用户问题筛选出要用到的文章后, 通过DocID获取文章内容,另外会提供访问文章正文的工具,最终返回的文章数最好在10篇以内
+    Args:
+        filter_words: 筛选知识库文章的关键词列表,只要符合其中任意一个关键词,就会返回该文章
+        match_limit: 筛选知识库文章的匹配计数,默认值为3,即只要符合3个关键词,就会返回该文章。但如果没有符合的文章或数量太少影响作答,可再次调用该工具,将match_limit减1,直到符合文章,但match_limit不能小于1
+    Returns:
+        一个包含知识库DocID、标题DocName和关键字keyword的列表,每个元素为字符串,字符串格式为"DocID:DocName|keyword",如果没有keyword,则格式为"DocID:DocName",每个元素之间用换行符分隔
+    """
+    print(f"正在查询知识库列表,筛选关键词:{filter_words} 匹配下限:{match_limit}")
+
+    kms_list_url = os.getenv("KMS_LIST_URL")
+    payload = {
+        "categorycodeList": [],
+        "ignoreTypeSub": False,
+        "ignoreStandardByTopic": True,
+    }
+
+    headers = {
+        "Accept": "application/json, text/plain, */*",
+        "Content-Type": "application/json",
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
+    }
+
+    try:
+        response = requests.post(
+            kms_list_url, headers=headers, json=payload, timeout=10
+        )
+        if response.status_code == 200:
+            data = response.json()
+            matched_lines = ""
+            for doc in data["docList"]:
+                doc_id = doc["DocID"]
+                doc_name = doc["DocName"]
+                doc_keywords = doc["keyword"]
+                search_text = f"{doc_name} {doc_keywords}".lower()
+
+                if not filter_words:
+                    line = (
+                        f"{doc_id}:{doc_name}|{doc_keywords}"
+                        if doc_keywords
+                        else f"{doc_id}:{doc_name}"
+                    )
+                    matched_lines += line + "\n"
+                else:
+                    match_count = get_unique_match_count(search_text, filter_words)
+                    if match_count >= match_limit:
+                        line = (
+                            f"{doc_id}:{doc_name}|{doc_keywords}"
+                            if doc_keywords
+                            else f"{doc_id}:{doc_name}"
+                        )
+                        matched_lines += line + "\n"
+
+            return matched_lines
+        else:
+            return f"请求失败,状态码: {response.status_code}"
+    except Exception as e:
+        return f"请求异常: {e}"
+
+
+@tool
+def get_knowledge_content(docid: str) -> str:
+    """获取知识库文章内容
+
+    Args:
+        docid: 知识库文章的DocID
+
+    Returns:
+        知识库文章内容
+    """
+    print(f"正在获取知识库文章内容,DocID: {docid}")
+
+    kms_view_url = os.getenv("KMS_VIEW_URL")
+    headers = {
+        "Accept": "application/json, text/plain, */*",
+        "Content-Type": "application/json",
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
+    }
+
+    try:
+        payload = {"docid": docid}
+        response = requests.post(
+            kms_view_url, headers=headers, json=payload, timeout=10
+        )
+        if response.status_code == 200:
+            data = response.json()
+            doc_html = data.get("DocHtml", "")
+            plain_text = html_to_text(doc_html)
+            print(f"已获取到ID: {docid}的文章内容,长度{len(plain_text)}")
+            return plain_text
+        else:
+            return f"请求失败,状态码: {response.status_code}"
+    except Exception as e:
+        return f"请求异常: {e}"

+ 47 - 0
tools/sale_tools.py

@@ -0,0 +1,47 @@
+from langchain.tools import tool
+from .base_tool import call_csharp_api, get_tool_prompt
+
+
+def get_sale_amt_default_config():
+    """get_sale_amt 工具的默认配置"""
+    return {
+        "get_sale_amt": {
+            "基础描述": "获取指定时间范围的销售金额,按月汇总",
+            "入参说明": {
+                "backend_url": "后端API地址",
+                "token": "认证令牌",
+                "funtion_name": "函数名称; get_sale_amt_by_month:按月汇总销售额; get_sale_amt_by_day:按天汇总销售额; get_sale_amt_by_produce:产品销售额; get_sale_amt_by_cus:客户销售额",
+                "firstdate": "开始日期,格式YYYY-MM-DD",
+                "lastdate": "结束日期,格式YYYY-MM-DD 23:59:59",
+            },
+            "返回值说明": {
+                "格式": "一个包含销售金额的字符串",
+            },
+            "输出格式要求": [
+                "以自然语言描述形式输出,不要使用表格",
+                "重复信息要总结归纳,精简显示",
+            ],
+        }
+    }
+
+
+tool_description = get_tool_prompt("get_sale_amt", get_sale_amt_default_config())
+
+
+def get_sale_amt_func(
+    backend_url: str, token: str, funtion_name: str, firstdate: str, lastdate: str
+) -> str:
+    """实际的函数实现"""
+    print(f"正在获取销售金额{funtion_name},时间范围:{firstdate} 至 {lastdate}")
+
+    return call_csharp_api(
+        backend_url,
+        token,
+        "sale_data_ai",
+        funtion_name,
+        {"arg_firstdate": firstdate, "arg_lastdate": lastdate},
+    )
+
+
+get_sale_amt_func.__doc__ = tool_description
+get_sale_amt = tool(get_sale_amt_func)

+ 234 - 0
tools/tool_factory.py

@@ -0,0 +1,234 @@
+import os
+import sys
+import importlib
+import inspect
+from pathlib import Path
+from typing import List
+from langchain.tools import BaseTool
+
+
+def get_all_tools() -> List[BaseTool]:
+    """
+    自动发现并返回所有工具 - 修复检查逻辑
+    """
+    tools = []
+
+    print("🛠️ 开始自动发现工具...")
+
+    # 获取项目根目录
+    project_root = Path(__file__).parent.parent
+    tools_dir = Path(__file__).parent
+
+    # 将项目根目录添加到Python路径
+    if str(project_root) not in sys.path:
+        sys.path.insert(0, str(project_root))
+
+    # 扫描工具文件
+    tool_files = []
+    for file_path in tools_dir.glob("*_tools.py"):
+        if file_path.is_file():
+            module_name = file_path.stem
+            tool_files.append(module_name)
+            print(f"📦 发现工具文件: {module_name}")
+
+    if not tool_files:
+        tool_files = ["knowledge_tools", "sale_tools"]
+        print("⚠️ 使用默认工具列表")
+
+    for module_name in tool_files:
+        try:
+            # 导入模块
+            full_module_path = f"tools.{module_name}"
+            module = importlib.import_module(full_module_path)
+            # print(f"✅ 加载模块: {module_name}")
+
+            # 查找工具 - 使用更全面的方法
+            tool_count = 0
+
+            # 方法1: 检查模块的所有属性
+            for attr_name in dir(module):
+                if attr_name.startswith("_"):
+                    continue
+
+                attr = getattr(module, attr_name)
+
+                # # 详细调试信息
+                # print(f"  🔍 检查 {attr_name}:")
+                # print(f"    类型: {type(attr)}")
+
+                # 检查是否是BaseTool实例
+                if isinstance(attr, BaseTool):
+                    tools.append(attr)
+                    tool_count += 1
+                    # print(f"  ✅ 发现BaseTool工具: {getattr(attr, 'name', attr_name)}")
+                    continue
+
+                # 检查是否是函数且具有工具属性
+                if callable(attr):
+                    # print(f"    是否有name属性: {hasattr(attr, 'name')}")
+                    # if hasattr(attr, "name"):
+                    #     print(f"    name值: {getattr(attr, 'name', '无')}")
+                    # print(f"    是否有description属性: {hasattr(attr, 'description')}")
+                    # if hasattr(attr, "description"):
+                    #     print(
+                    #         f"    description前50字: {getattr(attr, 'description', '')[:50]}"
+                    #     )
+
+                    # 检查是否是工具函数
+                    if is_tool_function(attr):
+                        tools.append(attr)
+                        tool_count += 1
+                        # print(f"  ✅ 发现工具函数: {attr_name}")
+
+            # 方法2: 检查模块的全局变量
+            print(f"  🔍 检查模块全局变量...")
+            for name, value in module.__dict__.items():
+                if name.startswith("_"):
+                    continue
+
+                if isinstance(value, BaseTool):
+                    if value not in tools:  # 避免重复添加
+                        tools.append(value)
+                        tool_count += 1
+                        # print(
+                        #     f"  ✅ 从全局变量发现BaseTool工具: {getattr(value, 'name', name)}"
+                        # )
+
+            if tool_count == 0:
+                # print(f"  ⚠️ 模块 {module_name} 中未发现工具")
+                # 尝试手动创建工具
+                manual_tools = create_tools_manually(module_name, module)
+                if manual_tools:
+                    tools.extend(manual_tools)
+                    tool_count = len(manual_tools)
+                    # print(f"  🔧 手动创建了 {tool_count} 个工具")
+            else:
+                print(f"  📊 模块 {module_name} 中发现 {tool_count} 个工具")
+
+        except Exception as e:
+            print(f"❌ 加载模块 {module_name} 失败: {e}")
+
+    print(f"🎯 总共发现 {len(tools)} 个工具")
+
+    # # 打印工具详情
+    # for i, tool in enumerate(tools):
+    #     tool_name = getattr(tool, "name", f"tool_{i+1}")
+    #     tool_desc = getattr(tool, "description", "无描述")
+    #     print(f"  {i+1}. {tool_name}: {tool_desc[:50]}...")
+
+    return tools
+
+
+def create_tools_manually(module_name: str, module) -> List[BaseTool]:
+    """手动创建工具 - 针对@tool装饰器的问题"""
+    tools = []
+
+    if module_name == "knowledge_tools":
+        # 手动导入并创建知识库工具
+        try:
+            from langchain.tools import tool
+
+            # 检查模块中是否有工具函数
+            if hasattr(module, "get_knowledge_list"):
+                func = getattr(module, "get_knowledge_list")
+                if callable(func):
+                    # 使用@tool装饰器重新创建工具
+                    tool_instance = tool(func)
+                    tools.append(tool_instance)
+                    print(f"  🔧 手动创建工具: get_knowledge_list")
+
+            if hasattr(module, "get_knowledge_content"):
+                func = getattr(module, "get_knowledge_content")
+                if callable(func):
+                    tool_instance = tool(func)
+                    tools.append(tool_instance)
+                    print(f"  🔧 手动创建工具: get_knowledge_content")
+
+        except Exception as e:
+            print(f"  ❌ 手动创建知识库工具失败: {e}")
+
+    elif module_name == "sale_tools":
+        # 手动创建销售工具
+        try:
+            from langchain.tools import tool
+
+            if hasattr(module, "get_sale_amt"):
+                func = getattr(module, "get_sale_amt")
+                if callable(func):
+                    tool_instance = tool(func)
+                    tools.append(tool_instance)
+                    print(f"  🔧 手动创建工具: get_sale_amt")
+
+        except Exception as e:
+            print(f"  ❌ 手动创建销售工具失败: {e}")
+
+    return tools
+
+
+def is_tool_instance(obj) -> bool:
+    """检查是否是工具实例"""
+    if not callable(obj):
+        return False
+
+    # 检查是否是BaseTool实例
+    if isinstance(obj, BaseTool):
+        return True
+
+    # 检查是否有工具的标准属性
+    tool_attrs = ["name", "description"]
+    if all(hasattr(obj, attr) for attr in tool_attrs):
+        return True
+
+    return False
+
+
+def is_tool_function(obj) -> bool:
+    """检查是否是工具函数"""
+    if not callable(obj):
+        return False
+
+    # 检查是否被@tool装饰
+    if hasattr(obj, "_is_tool") and getattr(obj, "_is_tool", False):
+        return True
+
+    # 检查是否有tool属性
+    if hasattr(obj, "tool"):
+        return True
+
+    # 检查是否是函数且具有工具属性
+    if callable(obj) and hasattr(obj, "name") and hasattr(obj, "description"):
+        return True
+
+    return False
+
+
+# 测试函数
+def test_tool_detection():
+    """测试工具检测"""
+    print("🧪 测试工具检测...")
+
+    # 导入一个模块测试
+    import tools.knowledge_tools as kt
+
+    print(f"模块: {kt}")
+
+    for attr_name in dir(kt):
+        if not attr_name.startswith("_"):
+            attr = getattr(kt, attr_name)
+            print(f"\n🔍 检查 {attr_name}:")
+            print(f"  类型: {type(attr)}")
+            print(f"  可调用: {callable(attr)}")
+
+            if callable(attr):
+                print(f"  是否有name属性: {hasattr(attr, 'name')}")
+                if hasattr(attr, "name"):
+                    print(f"  name值: {getattr(attr, 'name', '无')}")
+
+                print(f"  是否有description属性: {hasattr(attr, 'description')}")
+                if hasattr(attr, "description"):
+                    print(
+                        f"  description前50字: {getattr(attr, 'description', '')[:50]}"
+                    )
+
+                print(f"  是否被@tool装饰: {is_tool_function(attr)}")
+                print(f"  是否是工具实例: {is_tool_instance(attr)}")

+ 48 - 0
tools/ware_tools.py

@@ -0,0 +1,48 @@
+from langchain.tools import tool
+from .base_tool import call_csharp_api, get_tool_prompt
+
+
+def get_mtrlware_default_config():
+    """get_mtrlware_data 工具的默认配置"""
+    return {
+        "get_mtrlware_data": {
+            "基础描述": "获取指定物料的库存信息",
+            "入参说明": {
+                "backend_url": "后端API地址",
+                "token": "认证令牌",
+                "mtrlname": "物料名称",
+            },
+            "返回值说明": {
+                "格式": "一个包含物料库存数据的字符串",
+                "字段含义": "mtrlcode:物料编码, mtrlname:物料名称, storagename:仓库名称, noallocqty:库存数量, unit:单位, noauditingqty:已开单数量, notauditnoallocqty:未开单数量, pzinfo:配置信息, buydays:采购周期天数",
+            },
+            "输出格式要求": [
+                "以自然语言描述形式输出,不要使用表格",
+                "重复信息要总结归纳,精简显示",
+            ],
+        }
+    }
+
+
+# 获取工具描述
+tool_description = get_tool_prompt("get_mtrlware_data", get_mtrlware_default_config)
+
+
+def get_mtrlware_data_func(backend_url: str, token: str, mtrlname: str) -> str:
+    """实际的函数实现"""
+    print(f"正在获取物料{mtrlname}的库存数据")
+
+    return call_csharp_api(
+        backend_url,
+        token,
+        "ware_data_ai",
+        "get_mtrlware_data",
+        {"arg_mtrlname": mtrlname},
+    )
+
+
+# 3. 设置文档字符串
+get_mtrlware_data_func.__doc__ = tool_description
+
+# 4. 最后应用装饰器
+get_mtrlware_data = tool(get_mtrlware_data_func)

+ 235 - 0
utils/device_id.py

@@ -0,0 +1,235 @@
+# device_id.py
+import socket
+import uuid
+from typing import Optional, Dict, List
+import winreg
+import subprocess
+import os
+import hashlib
+import socket
+import uuid
+import platform
+import re
+
+
+def get_windows_device_id() -> Optional[str]:
+    """Windows: 获取BIOS UUID"""
+    uuid = None
+
+    # 方法1: 使用wmic命令
+    try:
+        result = subprocess.run(
+            ["wmic", "csproduct", "get", "uuid"],
+            capture_output=True,
+            text=True,
+            shell=True,
+            creationflags=subprocess.CREATE_NO_WINDOW,
+        )
+
+        if result.returncode == 0:
+            # 按行分割,过滤空行
+            lines = [
+                line.strip()
+                for line in result.stdout.strip().split("\n")
+                if line.strip()
+            ]
+
+            # 找到UUID行
+            for line in lines:
+                if len(line) == 36 and line.count("-") == 4:
+                    uuid = line
+                    break
+    except Exception as e:
+        print(f"WMIC获取失败: {e}")
+
+    # 方法2: 如果wmic失败,使用PowerShell
+    if not uuid:
+        try:
+            result = subprocess.run(
+                [
+                    "powershell",
+                    "-Command",
+                    "(Get-WmiObject -Class Win32_ComputerSystemProduct).UUID",
+                ],
+                capture_output=True,
+                text=True,
+                shell=True,
+            )
+
+            if result.returncode == 0:
+                uuid = result.stdout.strip()
+        except Exception as e:
+            print(f"PowerShell获取失败: {e}")
+
+    # 方法3: 如果PowerShell也失败,使用注册表
+    if not uuid:
+        try:
+            with winreg.OpenKey(
+                winreg.HKEY_LOCAL_MACHINE,
+                r"SOFTWARE\Microsoft\Windows NT\CurrentVersion",
+            ) as key:
+                uuid, _ = winreg.QueryValueEx(key, "ProductId")
+        except Exception as e:
+            print(f"注册表获取失败: {e}")
+
+    return uuid
+
+
+def get_linux_machine_id() -> Optional[str]:
+    """Linux: 获取机器ID"""
+    machine_id_paths = [
+        "/etc/machine-id",  # systemd系统
+        "/var/lib/dbus/machine-id",  # 老版本
+    ]
+
+    for path in machine_id_paths:
+        if os.path.exists(path):
+            try:
+                with open(path, "r") as f:
+                    machine_id = f.read().strip()
+                    if machine_id:
+                        return machine_id
+            except:
+                continue
+    return None
+
+
+def get_linux_system_uuid() -> Optional[str]:
+    """Linux: 获取系统UUID"""
+    try:
+        if os.path.exists("/sys/devices/virtual/dmi/id/product_uuid"):
+            with open("/sys/devices/virtual/dmi/id/product_uuid", "r") as f:
+                uuid = f.read().strip()
+                if uuid and len(uuid) == 36:
+                    return uuid
+    except:
+        pass
+    return None
+
+
+def get_mac_serial_number() -> Optional[str]:
+    """macOS: 获取序列号"""
+    try:
+        result = subprocess.run(
+            ["system_profiler", "SPHardwareDataType"], capture_output=True, text=True
+        )
+
+        if result.returncode == 0:
+            output = result.stdout
+            for line in output.split("\n"):
+                if "Serial Number" in line or "序列号" in line:
+                    parts = line.split(":")
+                    if len(parts) > 1:
+                        serial = parts[1].strip()
+                        if serial:
+                            return serial
+    except:
+        pass
+    return None
+
+
+def get_platform_device_id() -> Optional[str]:
+    """获取平台设备ID"""
+    system = platform.system()
+
+    if system == "Windows":
+        return get_windows_device_id()
+    elif system == "Linux":
+        # 优先使用系统UUID
+        uuid = get_linux_system_uuid()
+        if uuid:
+            return uuid
+        # 回退到machine-id
+        return get_linux_machine_id()
+    elif system == "Darwin":  # macOS
+        return get_mac_serial_number()
+
+    return None
+
+
+def get_device_id():
+    """获取设备特征码"""
+    device_id = get_platform_device_id()
+    if device_id:
+        # 清理设备ID
+        clean_id = re.sub(r"[^a-zA-Z0-9]", "", device_id).upper()
+        fingerprint = hashlib.sha256(clean_id.encode()).hexdigest()[:24].upper()
+    else:
+        # 回退方案
+        info = {
+            "hostname": socket.gethostname(),
+            "system": platform.system(),
+            "machine": platform.machine(),
+        }
+        fingerprint_data = "|".join([f"{k}:{v}" for k, v in sorted(info.items())])
+        fingerprint = hashlib.sha256(fingerprint_data.encode()).hexdigest()[:24].upper()
+
+    return fingerprint
+
+
+def get_device_info() -> Dict:
+    """获取系统详细信息"""
+    info = {}
+
+    try:
+        if platform.system() == "Windows":
+            # 获取Windows产品ID
+            with winreg.OpenKey(
+                winreg.HKEY_LOCAL_MACHINE,
+                r"SOFTWARE\Microsoft\Windows NT\CurrentVersion",
+            ) as key:
+                info["product_id"] = winreg.QueryValueEx(key, "ProductId")[0]
+                info["product_name"] = winreg.QueryValueEx(key, "ProductName")[0]
+
+        # 获取BIOS UUID
+        result = subprocess.run(
+            ["wmic", "csproduct", "get", "uuid"],
+            capture_output=True,
+            text=True,
+            shell=True,
+            creationflags=subprocess.CREATE_NO_WINDOW,
+        )
+
+        if result.returncode == 0:
+            lines = [line.strip() for line in result.stdout.strip().split("\n")]
+            # 过滤掉空行
+            lines = [line for line in lines if line]
+
+            if len(lines) > 1:
+                uuid = lines[1] if len(lines) > 1 else lines[0]
+                if uuid and uuid != "":
+                    info["bios_uuid"] = uuid
+
+        # 获取主机名
+        info["hostname"] = socket.gethostname()
+
+    except Exception as e:
+        info["error"] = str(e)
+
+    return info
+
+
+if __name__ == "__main__":
+    device_id = get_device_id()
+    print(f"设备特征码: {device_id}")
+    print(f"设备信息: {get_platform_device_id()}")
+
+    with winreg.OpenKey(
+        winreg.HKEY_LOCAL_MACHINE,
+        r"SOFTWARE\Microsoft\Windows NT\CurrentVersion",
+    ) as key:
+        uuid, _ = winreg.QueryValueEx(key, "ProductId")
+    print(f"Windows产品ID: {uuid}")
+    # info = get_device_info()
+    # for key, value in info.items():
+    #     print(f"  {key}: {value}")
+    system = platform.system()
+
+    info = {
+        "system": system,
+        "release": platform.release(),
+        "machine": platform.machine(),
+        "node": platform.node(),
+        "hostname": socket.gethostname(),
+    }
+    print(info)

+ 92 - 0
utils/logger.py

@@ -0,0 +1,92 @@
+import logging
+import json
+from datetime import datetime
+import os
+from typing import Dict, Any
+from pathlib import Path
+
+# 确保日志目录存在
+LOG_DIR = Path("chat_logs")
+LOG_DIR.mkdir(exist_ok=True)
+
+
+# 配置日志
+def setup_logging():
+    """配置日志系统"""
+    # 主日志记录器
+    logger = logging.getLogger("chat_logger")
+    logger.setLevel(logging.INFO)
+
+    # 避免重复添加handler
+    if not logger.handlers:
+        # 文件处理器 - 按天分割
+        log_file = LOG_DIR / f"chat_{datetime.now().strftime('%Y%m%d')}.log"
+        file_handler = logging.FileHandler(log_file, encoding="utf-8")
+        file_handler.setLevel(logging.INFO)
+
+        # 控制台处理器
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(logging.INFO)
+
+        # 格式化
+        formatter = logging.Formatter(
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+        file_handler.setFormatter(formatter)
+        console_handler.setFormatter(formatter)
+
+        logger.addHandler(file_handler)
+        logger.addHandler(console_handler)
+
+    return logger
+
+
+# 初始化日志
+chat_logger = setup_logging()
+
+
+def log_chat_entry(user_id: str, user_message: str, agent_response: Dict[str, Any]):
+    """记录完整的对话日志"""
+    try:
+        log_entry = {
+            "timestamp": datetime.now().isoformat(),
+            "user_id": user_id,
+            "user_message": user_message,
+            "agent_response": {
+                "final_answer": agent_response.get("final_answer", ""),
+                "all_ai_messages_count": len(agent_response.get("all_ai_messages", [])),
+                "all_messages_count": len(agent_response.get("all_messages", [])),
+                "tool_calls_count": len(agent_response.get("tool_calls", [])),
+            },
+            "all_messages": [
+                {
+                    "type": msg.get("type"),
+                    "content": msg.get("content", "")[:500],  # 限制长度
+                    "tool_calls": msg.get("tool_calls"),
+                    "index": msg.get("index"),
+                }
+                for msg in agent_response.get("all_messages", [])
+            ],
+            "tool_calls": agent_response.get("tool_calls", []),
+        }
+
+        # 记录到日志文件
+        chat_logger.info(f"对话记录 - Thread: {user_id}")
+        chat_logger.info(f"用户消息: {user_message}")
+        chat_logger.info(
+            f"Agent响应: {agent_response.get('final_answer', '')[:200]}..."
+        )
+
+        # 保存详细日志到单独文件
+        detailed_log_file = (
+            LOG_DIR
+            / f"detailed_{user_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        )
+        with open(detailed_log_file, "w", encoding="utf-8") as f:
+            json.dump(log_entry, f, ensure_ascii=False, indent=2)
+
+        chat_logger.info(f"详细日志已保存到: {detailed_log_file}")
+
+    except Exception as e:
+        chat_logger.error(f"记录日志时出错: {str(e)}")

+ 244 - 0
utils/registration.py

@@ -0,0 +1,244 @@
+import json
+import base64
+import hashlib
+from datetime import datetime
+from Crypto.Cipher import AES
+from Crypto.Util.Padding import pad, unpad
+import os
+import sys
+from typing import Tuple, Optional, Dict
+
+
+class RegistrationValidator:
+    def __init__(self, secret_key: str):
+        """
+        初始化注册验证器
+
+        Args:
+            secret_key: 密钥(必须与生成时相同)
+        """
+        if len(secret_key) not in [16, 24, 32]:
+            # 自动调整密钥长度
+            if len(secret_key) < 16:
+                secret_key = secret_key.ljust(16, "0")
+            elif len(secret_key) < 24:
+                secret_key = secret_key.ljust(24, "0")
+            elif len(secret_key) < 32:
+                secret_key = secret_key.ljust(32, "0")
+            else:
+                secret_key = secret_key[:32]
+
+        self.secret_key = secret_key.encode("utf-8")
+
+    def _get_cipher(self, iv=None):
+        """获取AES加密对象"""
+        if iv is None:
+            # 使用密钥的前16字节作为固定IV
+            iv = hashlib.md5(self.secret_key).digest()[:16]
+        return AES.new(self.secret_key, AES.MODE_CBC, iv)
+
+    def decrypt_registration(self, registration_code: str) -> Optional[Dict]:
+        """
+        解密注册码
+
+        Returns:
+            注册信息字典,解密失败返回None
+        """
+        try:
+            # Base64解码
+            encrypted = base64.urlsafe_b64decode(registration_code.encode("utf-8"))
+
+            # AES解密
+            cipher = self._get_cipher()
+            decrypted = unpad(cipher.decrypt(encrypted), AES.block_size)
+
+            # JSON解析
+            reg_data = json.loads(decrypted.decode("utf-8"))
+
+            return reg_data
+        except Exception as e:
+            print(f"解密注册码失败: {e}")
+            return None
+
+    def validate(self, device_id: str, registration_code: str) -> Tuple[bool, str]:
+        """
+        验证注册码
+
+        Args:
+            device_id: 当前设备特征码
+            registration_code: 注册码
+
+        Returns:
+            (是否有效, 消息)
+        """
+        if not registration_code or registration_code.strip() == "":
+            return False, "未提供注册码"
+
+        # 解密注册码
+        reg_info = self.decrypt_registration(registration_code)
+        if not reg_info:
+            return False, "注册码无效或已损坏"
+
+        # 检查设备ID
+        reg_device_id = reg_info.get("device_id", "")
+        if reg_device_id != device_id:
+            return (
+                False,
+                f"设备不匹配。注册设备: {reg_device_id},当前设备: {device_id}",
+            )
+
+        # 检查有效期
+        expire_date_str = reg_info.get("expire_date", "")
+        if not expire_date_str:
+            return False, "注册信息不完整(缺少有效期)"
+
+        try:
+            expire_date = datetime.strptime(expire_date_str, "%Y-%m-%d").date()
+            current_date = datetime.now().date()
+
+            if current_date > expire_date:
+                days_overdue = (current_date - expire_date).days
+                return (
+                    False,
+                    f"注册已过期 {days_overdue} 天!有效期至: {expire_date_str}",
+                )
+
+            # 计算剩余天数
+            days_left = (expire_date - current_date).days
+
+            if days_left <= 7:
+                return (
+                    True,
+                    f"注册有效({days_left}天后到期)有效期至: {expire_date_str}",
+                )
+            elif days_left <= 30:
+                return True, f"注册有效(剩余{days_left}天)有效期至: {expire_date_str}"
+            else:
+                return True, f"注册有效,有效期至: {expire_date_str}"
+
+        except ValueError as e:
+            return False, f"日期格式错误: {str(e)}"
+
+
+def check_registration_required() -> bool:
+    """
+    检查是否需要注册验证
+
+    可以通过环境变量控制是否启用注册验证
+    """
+    return os.getenv("REGISTRATION_REQUIRED", "true").lower() == "true"
+
+
+def get_registration_code() -> str:
+    """
+    获取注册码
+
+    优先级:
+    1. 环境变量 REGISTRATION_CODE(取消)
+    2. 配置文件 .registration
+    3. 当前目录的 registration.txt
+    """
+    # # 从环境变量获取
+    # reg_code = os.getenv("REGISTRATION_CODE", "").strip()
+    # if reg_code:
+    #     return reg_code
+    current_file_dir = os.path.dirname(os.path.abspath(__file__))  # utils目录
+    project_root = os.path.dirname(current_file_dir)  # 项目根目录
+    # 配置文件列表
+    config_files = [
+        os.path.join(project_root, ".registration"),
+        os.path.join(project_root, "registration.txt"),
+        os.path.join(project_root, "config", "registration.txt"),
+        # 保留当前目录作为备用
+        ".registration",
+        "registration.txt",
+    ]
+
+    for config_file in config_files:
+        if os.path.exists(config_file):
+            try:
+                with open(config_file, "r") as f:
+                    for line in f:
+                        if line.startswith("REGISTRATION_CODE="):
+                            return line.strip().split("=", 1)[1]
+                        elif (
+                            not line.startswith("#")
+                            and "=" not in line
+                            and len(line.strip()) > 20
+                        ):
+                            # 可能是单独的注册码
+                            return line.strip()
+            except:
+                continue
+
+    return ""
+
+
+def validate_registration(
+    secret_key: str, device_id_func=None
+) -> Tuple[bool, str, Optional[Dict]]:
+    """
+    验证注册的主函数
+
+    Args:
+        secret_key: 密钥
+        device_id_func: 获取设备ID的函数,默认为导入device_id模块
+
+    Returns:
+        (是否通过, 消息, 注册信息)
+    """
+    # 如果不要求注册验证,直接通过
+    if not check_registration_required():
+        return True, "注册验证已禁用", None
+
+    # 获取设备ID函数
+    if device_id_func is None:
+        try:
+            from device_id import get_device_id
+
+            device_id_func = get_device_id
+        except ImportError:
+            # 如果在同一文件中定义了get_device_id
+            from device_id import get_device_id as device_id_func
+
+    # 获取当前设备ID
+    try:
+        current_device_id = device_id_func()
+    except Exception as e:
+        return False, f"获取设备特征码失败: {e}", None
+
+    # 获取注册码
+    registration_code = get_registration_code()
+    if not registration_code:
+        return (
+            False,
+            "未找到注册码。请设置REGISTRATION_CODE环境变量或创建注册文件。",
+            None,
+        )
+
+    # 创建验证器并验证
+    validator = RegistrationValidator(secret_key)
+
+    # 解密注册信息
+    reg_info = validator.decrypt_registration(registration_code)
+    if not reg_info:
+        return False, "注册码无效或已损坏", None
+
+    # 验证
+    is_valid, message = validator.validate(current_device_id, registration_code)
+
+    return is_valid, message, reg_info
+
+
+# 使用示例
+if __name__ == "__main__":
+    # 测试验证
+    secret_key = "ialwayslovelongjoe"
+    from device_id import get_device_id
+
+    is_valid, message, reg_info = validate_registration(secret_key, get_device_id)
+
+    print(f"验证结果: {is_valid}")
+    print(f"消息: {message}")
+    if reg_info:
+        print(f"注册信息: {reg_info}")

+ 84 - 0
utils/registration_manager.py

@@ -0,0 +1,84 @@
+import time
+from typing import Optional, Tuple
+from utils.device_id import get_device_id
+from utils.registration import validate_registration
+from regex import T
+from utils.logger import chat_logger
+
+
+class RegistrationManager:
+    """注册状态管理器"""
+
+    def __init__(self, check_interval: int = 3600):
+        self._is_registered: Optional[bool] = None
+        self._last_check_time: float = 0
+        self._check_interval = check_interval  # 检查间隔(秒)
+
+    async def check_registration(self) -> bool:
+        """检查注册状态,带缓存"""
+        current_time = time.time()
+
+        # 如果缓存未过期,直接返回缓存结果
+        if (
+            self._is_registered is True
+            and current_time - self._last_check_time < self._check_interval
+        ):
+            return self._is_registered
+
+        try:
+            is_valid, message = await self._check_license_validity()
+
+            self._is_registered = is_valid
+            self._last_check_time = current_time
+
+            if is_valid:
+                chat_logger.info(f"✅ {message}")
+            else:
+                chat_logger.warning(f"❌ {message}")
+
+            return is_valid
+
+        except Exception as e:
+            chat_logger.error(f"注册检查失败: {str(e)}")
+            # 检查失败时保守处理
+            self._is_registered = False
+            return False
+
+    async def _check_license_validity(self) -> Tuple[bool, str]:
+        """实际的注册检查逻辑"""
+
+        secret_key = "ialwayslovelongjoe"
+
+        try:
+            is_valid, message, reg_info = validate_registration(
+                secret_key, get_device_id
+            )
+
+            print(f"验证结果: {is_valid}")
+            print(f"消息: {message}")
+            if reg_info:
+                print(f"注册信息: {reg_info}")
+
+            return is_valid, message  # 临时返回True
+
+        except Exception as e:
+            chat_logger.error(f"许可证检查异常: {str(e)}")
+            return False, message
+
+    def force_refresh(self):
+        """强制刷新注册状态"""
+        self._is_registered = None
+        self._last_check_time = 0
+        chat_logger.info("强制刷新注册状态")
+
+    def get_status(self):
+        """获取当前状态信息"""
+        return {
+            "is_registered": self._is_registered,
+            "last_check_time": self._last_check_time,
+            "check_interval": self._check_interval,
+        }
+
+
+# 全局实例
+registration_manager = RegistrationManager()