diff --git a/pkg/workloads/cortex/serve/serve.py b/pkg/workloads/cortex/serve/serve.py index 4cf18b8eca..cf560a09a1 100644 --- a/pkg/workloads/cortex/serve/serve.py +++ b/pkg/workloads/cortex/serve/serve.py @@ -176,13 +176,16 @@ async def parse_payload(request: Request, call_next): return await call_next(request) -def predict(request: Request): +async def predict(request: Request): tasks = BackgroundTasks() api = local_cache["api"] predictor_impl = local_cache["predictor_impl"] kwargs = build_predict_kwargs(request) - prediction = predictor_impl.predict(**kwargs) + if inspect.iscoroutinefunction(predictor_impl.predict): + prediction = await predictor_impl.predict(**kwargs) + else: + prediction = predictor_impl.predict(**kwargs) if isinstance(prediction, bytes): response = Response(content=prediction, media_type="application/octet-stream")