diff --git a/pkg/workloads/cortex/serve/serve.py b/pkg/workloads/cortex/serve/serve.py index 25b6f15e87..93fa5e4743 100644 --- a/pkg/workloads/cortex/serve/serve.py +++ b/pkg/workloads/cortex/serve/serve.py @@ -25,6 +25,7 @@ from fastapi import Body, FastAPI from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.background import BackgroundTasks @@ -55,6 +56,15 @@ ) app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + local_cache = {"api": None, "predictor_impl": None, "client": None, "class_set": set()} @@ -90,21 +100,18 @@ def is_prediction_request(request): @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request, e): response = Response(content=str(e.detail), status_code=e.status_code) - apply_cors_headers(request, response) return response @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, e): response = Response(content=str(e), status_code=400) - apply_cors_headers(request, response) return response @app.exception_handler(Exception) async def uncaught_exception_handler(request, e): response = Response(content="internal server error", status_code=500) - apply_cors_headers(request, response) return response @@ -132,20 +139,12 @@ async def register_request(request: Request, call_next): status_code = 500 if response is not None: status_code = response.status_code - apply_cors_headers(request, response) api = local_cache["api"] api.post_request_metrics(status_code, time.time() - request.state.start_time) return response -def apply_cors_headers(request: Request, response: Response): - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Headers"] = request.headers.get( - "Access-Control-Request-Headers", "*" - ) - - @app.post("/predict") def predict(request: Any = Body(..., media_type="application/json"), debug=False): api = local_cache["api"]