|
|
@@ -0,0 +1,314 @@
|
|
|
+import base64
|
|
|
+from datetime import datetime
|
|
|
+from fastapi import APIRouter, HTTPException
|
|
|
+
|
|
|
+from utils.device_id import get_device_id
|
|
|
+from .models import ChatRequest, ChatResponse, OCRRequest, MessageCreateBill
|
|
|
+from core.chat_service import chat_service
|
|
|
+from core.agent_manager import agent_manager
|
|
|
+from utils.logger import chat_logger
|
|
|
+from tools.tool_factory import get_all_tools
|
|
|
+import time
|
|
|
+from utils.registration_manager import registration_manager
|
|
|
+from core.ocr_service import PaddleOCRService
|
|
|
+from core.document_processor.document_service import DocumentProcessingService
|
|
|
+
|
|
|
+# 初始化服务
|
|
|
+ocr_service = PaddleOCRService(
|
|
|
+ api_url="https://a8l0g1qda8zd48nb.aistudio-app.com/ocr",
|
|
|
+ token="f97d214abf87d5ea3c156e21257732a3b19661cb",
|
|
|
+)
|
|
|
+doc_service = DocumentProcessingService(ocr_service=ocr_service)
|
|
|
+
|
|
|
+
|
|
|
+router = APIRouter()
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/chat", response_model=ChatResponse)
|
|
|
+async def chat_endpoint(request: ChatRequest):
|
|
|
+ """聊天接口"""
|
|
|
+ try:
|
|
|
+ result = await chat_service.process_chat_request(request.model_dump())
|
|
|
+ return ChatResponse(**result)
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"API处理失败: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/ocr")
|
|
|
+async def orc(request: OCRRequest):
|
|
|
+ """OCR接口"""
|
|
|
+ try:
|
|
|
+ chat_logger.info(f"开始进行图片识别")
|
|
|
+
|
|
|
+ # 1. 解码Base64图片
|
|
|
+ try:
|
|
|
+ if "," in request.image:
|
|
|
+ # 去掉 data:image/xxx;base64, 前缀
|
|
|
+ base64_str = request.image.split(",", 1)[1]
|
|
|
+ else:
|
|
|
+ base64_str = request.image
|
|
|
+
|
|
|
+ image_bytes = base64.b64decode(base64_str)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"图片解码失败: {e}")
|
|
|
+ raise HTTPException(400, f"图片格式错误: {str(e)}")
|
|
|
+
|
|
|
+ # 2. OCR识别
|
|
|
+ result = await doc_service.pure_ocr(image_bytes=image_bytes)
|
|
|
+
|
|
|
+ # 3. 返回结果
|
|
|
+ return result
|
|
|
+
|
|
|
+ except HTTPException:
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"处理失败: {e}")
|
|
|
+ raise HTTPException(500, f"处理失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/ocr_create_bill")
|
|
|
+async def ocr_create_bill_endpoint(request: OCRRequest):
|
|
|
+ """
|
|
|
+ 处理单张图片
|
|
|
+ 请求格式:
|
|
|
+ {
|
|
|
+ "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
|
|
|
+ "type": "invoice",
|
|
|
+ }
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ chat_logger.info(f"开始处理 {request.type} 单据")
|
|
|
+
|
|
|
+ # 1. 解码Base64图片
|
|
|
+ try:
|
|
|
+ if "," in request.image:
|
|
|
+ # 去掉 data:image/xxx;base64, 前缀
|
|
|
+ base64_str = request.image.split(",", 1)[1]
|
|
|
+ else:
|
|
|
+ base64_str = request.image
|
|
|
+
|
|
|
+ image_bytes = base64.b64decode(base64_str)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"图片解码失败: {e}")
|
|
|
+ raise HTTPException(400, f"图片格式错误: {str(e)}")
|
|
|
+
|
|
|
+ # 2. 处理单据
|
|
|
+ result = await doc_service.ocr_create_bill(
|
|
|
+ image_bytes=image_bytes, document_type=request.type
|
|
|
+ )
|
|
|
+
|
|
|
+ # 3. 返回结果
|
|
|
+ return {
|
|
|
+ "success": True,
|
|
|
+ "type": request.type,
|
|
|
+ "text": result.get("ocr_text", ""), # 识别出的文本
|
|
|
+ "data": result.get("parsed_data", {}), # 结构化数据
|
|
|
+ "timestamp": datetime.now().isoformat(),
|
|
|
+ }
|
|
|
+
|
|
|
+ except HTTPException:
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"处理失败: {e}")
|
|
|
+ raise HTTPException(500, f"处理失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/message_create_bill")
|
|
|
+async def message_create_bill_endpoint(request: MessageCreateBill):
|
|
|
+ """
|
|
|
+ 通过文本消息辅助建立单据
|
|
|
+ 请求格式:
|
|
|
+ {
|
|
|
+ "message": "这是一条单据描述",
|
|
|
+ "document_type": "invoice",
|
|
|
+ }
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ result = await doc_service.message_create_bill(
|
|
|
+ message=request.message, document_type=request.document_type
|
|
|
+ )
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"处理失败: {e}")
|
|
|
+ raise HTTPException(500, f"处理失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/cache/status")
|
|
|
+async def cache_status():
|
|
|
+ """查看缓存状态 - 修复版本"""
|
|
|
+ try:
|
|
|
+ cache_status = await agent_manager.get_cache_status()
|
|
|
+
|
|
|
+ # ✅ 确保返回完整信息
|
|
|
+ return {
|
|
|
+ "success": True,
|
|
|
+ "cache_size": cache_status.get("cache_size", 0),
|
|
|
+ "cache_expiry_seconds": cache_status.get("cache_expiry_seconds", 3600),
|
|
|
+ "cache_entries_count": len(cache_status.get("cache_entries", [])),
|
|
|
+ "cache_entries": cache_status.get("cache_entries", []),
|
|
|
+ "timestamp": time.time(),
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"获取缓存状态失败: {str(e)}")
|
|
|
+ return {
|
|
|
+ "success": False,
|
|
|
+ "error": str(e),
|
|
|
+ "cache_size": 0,
|
|
|
+ "cache_entries_count": 0,
|
|
|
+ "cache_entries": [],
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@router.delete("/cache/clear")
|
|
|
+async def clear_cache():
|
|
|
+ """清空缓存"""
|
|
|
+ try:
|
|
|
+ # ✅ 使用异步版本
|
|
|
+ cleared = await agent_manager.clear_cache()
|
|
|
+ chat_logger.info(f"清空agent缓存, 清理数量={cleared}")
|
|
|
+ return {
|
|
|
+ "cleared_entries": cleared,
|
|
|
+ "message": "缓存已清空",
|
|
|
+ "status": "success",
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"清空缓存失败: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"清空缓存失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/health")
|
|
|
+async def health_check():
|
|
|
+ """健康检查"""
|
|
|
+ try:
|
|
|
+ # ✅ 使用异步版本
|
|
|
+
|
|
|
+ registration_status = await registration_manager.check_registration()
|
|
|
+
|
|
|
+ return {
|
|
|
+ "status": "healthy",
|
|
|
+ "service": "龙嘉软件AI助手API",
|
|
|
+ "registration_status": (
|
|
|
+ "已注册" if registration_status else "未注册或注册过期"
|
|
|
+ ),
|
|
|
+ "device_id": get_device_id(),
|
|
|
+ "timestamp": time.time(),
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ return {
|
|
|
+ "status": "unhealthy",
|
|
|
+ "service": "龙嘉软件AI助手API",
|
|
|
+ "error": str(e),
|
|
|
+ "timestamp": time.time(),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/tools/status")
|
|
|
+async def tools_status():
|
|
|
+ """查看工具状态"""
|
|
|
+ try:
|
|
|
+ tools = get_all_tools()
|
|
|
+ tool_info = []
|
|
|
+
|
|
|
+ for i, tool in enumerate(tools):
|
|
|
+ tool_info.append(
|
|
|
+ {
|
|
|
+ "index": i + 1,
|
|
|
+ "name": getattr(tool, "name", "unknown"),
|
|
|
+ "description": getattr(tool, "description", "unknown")[:100]
|
|
|
+ + "...",
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return {"total_tools": len(tools), "tools": tool_info, "status": "success"}
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"获取工具状态失败: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取工具状态失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/cache/initialize")
|
|
|
+async def initialize_cache():
|
|
|
+ """初始化缓存系统"""
|
|
|
+ try:
|
|
|
+ # ✅ 使用异步版本
|
|
|
+ await agent_manager.initialize()
|
|
|
+ chat_logger.info("缓存系统初始化完成")
|
|
|
+ return {"status": "success", "message": "缓存系统初始化完成"}
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"初始化缓存失败: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"初始化缓存失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/cache/shutdown")
|
|
|
+async def shutdown_cache():
|
|
|
+ """关闭缓存系统"""
|
|
|
+ try:
|
|
|
+ # ✅ 使用异步版本
|
|
|
+ cleared = await agent_manager.shutdown()
|
|
|
+ chat_logger.info(f"缓存系统已关闭,清理了 {cleared} 个实例")
|
|
|
+ return {
|
|
|
+ "cleared_entries": cleared,
|
|
|
+ "message": "缓存系统已关闭",
|
|
|
+ "status": "success",
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ chat_logger.error(f"关闭缓存失败: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"关闭缓存失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/admin/refresh-registration")
|
|
|
+async def refresh_registration():
|
|
|
+ """手动刷新注册状态(管理员用)"""
|
|
|
+ registration_manager.force_refresh()
|
|
|
+ return {"status": "success", "message": "注册状态已刷新"}
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/admin/registration-status")
|
|
|
+async def get_registration_status():
|
|
|
+ """获取注册状态(管理员用)"""
|
|
|
+ status = await registration_manager.check_registration()
|
|
|
+ status_info = registration_manager.get_status()
|
|
|
+ return {"is_registered": status, "status_info": status_info}
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/")
|
|
|
+async def root():
|
|
|
+ registration_status = await registration_manager.check_registration()
|
|
|
+
|
|
|
+ base_info = {
|
|
|
+ "service": "龙嘉软件AI助手API",
|
|
|
+ "version": "1.0.0",
|
|
|
+ "registration_status": "active" if registration_status else "expired",
|
|
|
+ "device_id": get_device_id(),
|
|
|
+ }
|
|
|
+
|
|
|
+ if registration_status:
|
|
|
+ base_info.update(
|
|
|
+ {
|
|
|
+ "endpoints": {
|
|
|
+ "POST /chat": "聊天",
|
|
|
+ "GET /health": "健康检查",
|
|
|
+ "GET /cache/status": "查看agent缓存状态",
|
|
|
+ "DELETE /cache/clear": "清空agent缓存",
|
|
|
+ "GET /": "API信息",
|
|
|
+ }
|
|
|
+ }
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ base_info.update(
|
|
|
+ {
|
|
|
+ "message": "⚠️ 服务注册已过期,部分功能受限",
|
|
|
+ "available_endpoints": {
|
|
|
+ "GET /health": "健康检查",
|
|
|
+ "GET /registration/status": "查看注册状态",
|
|
|
+ "GET /service/info": "服务完整信息",
|
|
|
+ "GET /": "API信息",
|
|
|
+ },
|
|
|
+ "restricted_endpoints": {"POST /chat": "AI聊天功能(需续费)"},
|
|
|
+ "support_contact": "请联系管理员续费服务",
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return base_info
|