registration_middleware.py 1.2 KB

123456789101112131415161718192021222324252627282930
  1. from fastapi import Request, HTTPException
  2. from fastapi.responses import JSONResponse
  3. from utils.registration_manager import registration_manager
  4. from utils.logger import chat_logger
  5. async def registration_check_middleware(request: Request, call_next):
  6. """
  7. 注册检查中间件
  8. 对需要注册验证的接口进行检查
  9. """
  10. # 定义需要检查注册状态的接口列表
  11. protected_paths = ["/chat", "/message_create_bill", "/ocr_create_bill", "/ocr"]
  12. # 只拦截POST /chat请求
  13. if request.url.path in protected_paths and request.method == "POST":
  14. if not await registration_manager.check_registration():
  15. chat_logger.warning(f"拒绝未注册访问: {request.client.host}")
  16. if request.url.path == "/chat":
  17. response_data = {
  18. "final_answer": "服务未注册或注册已过期,请联系管理员。"
  19. }
  20. else:
  21. response_data = "服务未注册或注册已过期,请联系管理员。"
  22. return JSONResponse(status_code=403, content=response_data)
  23. # 其他请求直接放行
  24. response = await call_next(request)
  25. return response