diff --git a/docs/deployments/predictors.md b/docs/deployments/predictors.md
index 569e1ee042..6d03939e8e 100644
--- a/docs/deployments/predictors.md
+++ b/docs/deployments/predictors.md
@@ -55,11 +55,13 @@ class PythonPredictor:
         """
         pass
 
-    def predict(self, payload):
+    def predict(self, payload, query_params, headers):
         """Called once per request. Preprocesses the request payload (if
         necessary), runs inference, and postprocesses the inference output
         (if necessary).
 
         Args:
-            payload: The parsed JSON request payload.
+            payload: The request payload (see below for the possible payload types) (optional).
+            query_params: A dictionary of the query parameters used in the request (optional).
+            headers: A dictionary of the headers sent in the request (optional).
 
         Returns:
             Prediction or a batch of predictions.
@@ -69,6 +71,8 @@ class PythonPredictor:
 
 For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as from where to download the model and initialization files, or any configurable model parameters. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
 
+The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `str` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.
+
 ### Examples
@@ -173,11 +177,13 @@ class TensorFlowPredictor:
         self.client = tensorflow_client
 
         # Additional initialization may be done here
 
-    def predict(self, payload):
+    def predict(self, payload, query_params, headers):
         """Called once per request. Preprocesses the request payload (if
         necessary), runs inference (e.g. by calling
         self.client.predict(model_input)), and postprocesses the inference
         output (if necessary).
 
         Args:
-            payload: The parsed JSON request payload.
+            payload: The request payload (see below for the possible payload types) (optional).
+            query_params: A dictionary of the query parameters used in the request (optional).
+            headers: A dictionary of the headers sent in the request (optional).
 
         Returns:
             Prediction or a batch of predictions.
@@ -190,6 +196,8 @@ Cortex provides a `tensorflow_client` to your Predictor's constructor. `tensorfl
 
 For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
 
+The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `str` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.
+
 ### Examples
@@ -249,11 +257,13 @@ class ONNXPredictor:
         self.client = onnx_client
 
         # Additional initialization may be done here
 
-    def predict(self, payload):
+    def predict(self, payload, query_params, headers):
         """Called once per request. Preprocesses the request payload (if
         necessary), runs inference (e.g. by calling
         self.client.predict(model_input)), and postprocesses the inference
         output (if necessary).
 
         Args:
-            payload: The parsed JSON request payload.
+            payload: The request payload (see below for the possible payload types) (optional).
+            query_params: A dictionary of the query parameters used in the request (optional).
+            headers: A dictionary of the headers sent in the request (optional).
 
         Returns:
             Prediction or a batch of predictions.
@@ -266,6 +276,8 @@ Cortex provides an `onnx_client` to your Predictor's constructor. `onnx_client`
 
 For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
 
+The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `str` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.
+
 ### Examples
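
To make the documented signature concrete, here is a minimal sketch of a `PythonPredictor` that declares all three optional arguments and branches on the payload type. The config key, form field name, and response shape are hypothetical:

```python
# Hypothetical predictor.py sketch. Cortex inspects predict()'s signature and
# passes only the arguments that are declared, so any subset of
# (payload, query_params, headers) is valid.
from starlette.datastructures import FormData, UploadFile


class PythonPredictor:
    def __init__(self, config):
        # config comes from the `config` section of the API configuration
        self.threshold = config.get("threshold", 0.5)  # hypothetical parameter

    def predict(self, payload, query_params, headers):
        if isinstance(payload, FormData):
            # multipart/form-data or application/x-www-form-urlencoded request;
            # file fields arrive as starlette UploadFile objects
            image = payload["image"]  # hypothetical form field
            data = image.file.read() if isinstance(image, UploadFile) else image
        elif isinstance(payload, bytes):
            # any other Content-Type: the raw bytes of the request body
            data = payload
        else:
            # application/json: the parsed JSON body (e.g. a dict or list)
            data = payload

        verbose = query_params.get("verbose") == "true"
        return {"input_size": len(data), "verbose": verbose}  # placeholder output
```

Since the serving layer passes only the arguments that appear in the signature, an existing `def predict(self, payload)` keeps working unchanged.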
diff --git a/pkg/workloads/cortex/lib/type/predictor.py b/pkg/workloads/cortex/lib/type/predictor.py
index e33d270ed4..037d663dd9 100644
--- a/pkg/workloads/cortex/lib/type/predictor.py
+++ b/pkg/workloads/cortex/lib/type/predictor.py
@@ -57,10 +57,12 @@ def initialize_client(self, model_dir=None, tf_serving_host=None, tf_serving_por
     def initialize_impl(self, project_dir, client=None):
         class_impl = self.class_impl(project_dir)
         try:
-            if self.type == "python":
-                return class_impl(self.config)
+            if self.type == "onnx":
+                return class_impl(onnx_client=client, config=self.config)
+            elif self.type == "tensorflow":
+                return class_impl(tensorflow_client=client, config=self.config)
             else:
-                return class_impl(client, self.config)
+                return class_impl(config=self.config)
         except Exception as e:
             raise UserRuntimeException(self.path, "__init__", str(e)) from e
         finally:
@@ -128,55 +130,90 @@ def _load_module(self, module_name, impl_path):
 
 PYTHON_CLASS_VALIDATION = {
     "required": [
-        {"name": "__init__", "args": ["self", "config"]},
-        {"name": "predict", "args": ["self", "payload"]},
+        {"name": "__init__", "required_args": ["self", "config"]},
+        {
+            "name": "predict",
+            "required_args": ["self"],
+            "optional_args": ["payload", "query_params", "headers"],
+        },
     ]
 }
 
 TENSORFLOW_CLASS_VALIDATION = {
     "required": [
-        {"name": "__init__", "args": ["self", "tensorflow_client", "config"]},
-        {"name": "predict", "args": ["self", "payload"]},
+        {"name": "__init__", "required_args": ["self", "tensorflow_client", "config"]},
+        {
+            "name": "predict",
+            "required_args": ["self"],
+            "optional_args": ["payload", "query_params", "headers"],
+        },
     ]
 }
 
 ONNX_CLASS_VALIDATION = {
     "required": [
-        {"name": "__init__", "args": ["self", "onnx_client", "config"]},
-        {"name": "predict", "args": ["self", "payload"]},
+        {"name": "__init__", "required_args": ["self", "onnx_client", "config"]},
+        {
+            "name": "predict",
+            "required_args": ["self"],
+            "optional_args": ["payload", "query_params", "headers"],
+        },
     ]
 }
 
 
 def _validate_impl(impl, impl_req):
-    for optional_func in impl_req.get("optional", []):
-        _validate_optional_fn_args(impl, optional_func["name"], optional_func["args"])
+    for optional_func_signature in impl_req.get("optional", []):
+        _validate_optional_fn_args(impl, optional_func_signature)
 
-    for required_func in impl_req.get("required", []):
-        _validate_required_fn_args(impl, required_func["name"], required_func["args"])
+    for required_func_signature in impl_req.get("required", []):
+        _validate_required_fn_args(impl, required_func_signature)
 
 
-def _validate_optional_fn_args(impl, fn_name, args):
-    if fn_name in vars(impl):
-        _validate_required_fn_args(impl, fn_name, args)
+def _validate_optional_fn_args(impl, func_signature):
+    if getattr(impl, func_signature["name"], None):
+        _validate_required_fn_args(impl, func_signature)
 
 
-def _validate_required_fn_args(impl, fn_name, args):
-    fn = getattr(impl, fn_name, None)
+def _validate_required_fn_args(impl, func_signature):
+    fn = getattr(impl, func_signature["name"], None)
     if not fn:
-        raise UserException('required function "{}" is not defined'.format(fn_name))
+        raise UserException(f'required function "{func_signature["name"]}" is not defined')
 
     if not callable(fn):
-        raise UserException('"{}" is defined, but is not a function'.format(fn_name))
+        raise UserException(f'"{func_signature["name"]}" is defined, but is not a function')
 
     argspec = inspect.getfullargspec(fn)
 
-    if argspec.args != args:
-        raise UserException(
-            'invalid signature for function "{}": expected arguments ({}) but found ({})'.format(
-                fn_name, ", ".join(args), ", ".join(argspec.args)
-            )
-        )
+    required_args = func_signature.get("required_args", [])
+    optional_args = func_signature.get("optional_args", [])
+    fn_str = f'{func_signature["name"]}({", ".join(argspec.args)})'
+
+    for arg_name in required_args:
+        if arg_name not in argspec.args:
+            raise UserException(
+                f'invalid signature for function "{fn_str}": "{arg_name}" is a required argument, but was not provided'
+            )
+
+        if arg_name == "self":
+            if argspec.args[0] != "self":
+                raise UserException(
+                    f'invalid signature for function "{fn_str}": "self" must be the first argument'
+                )
+
+    seen_args = []
+    for arg_name in argspec.args:
+        if arg_name not in required_args and arg_name not in optional_args:
+            raise UserException(
+                f'invalid signature for function "{fn_str}": "{arg_name}" is not a supported argument'
+            )
+
+        if arg_name in seen_args:
+            raise UserException(
+                f'invalid signature for function "{fn_str}": "{arg_name}" is duplicated'
+            )
+
+        seen_args.append(arg_name)
 
 
 tf_expected_dir_structure = """tensorflow model directories must have the following structure:
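
The net effect of the validation rewrite in `predictor.py` above is that `predict` no longer has to match one exact signature; any subset of the optional arguments, in any order, passes. A rough standalone illustration of the rule (a sketch, not the actual validator):

```python
import inspect


def signature_ok(fn, required_args, optional_args):
    # Mirrors the spirit of _validate_required_fn_args: every required argument
    # must be present, and nothing outside required + optional is allowed.
    args = inspect.getfullargspec(fn).args
    return all(a in args for a in required_args) and all(
        a in required_args or a in optional_args for a in args
    )


class A:
    def predict(self, payload):  # accepted
        ...


class B:
    def predict(self, headers, payload):  # accepted: argument order is free
        ...


class C:
    def predict(self, payload, extra):  # rejected: "extra" is unsupported
        ...


required, optional = ["self"], ["payload", "query_params", "headers"]
for cls in (A, B, C):
    print(cls.__name__, signature_ok(cls.predict, required, optional))
# prints: A True, B True, C False
```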
diff --git a/pkg/workloads/cortex/serve/requirements.txt b/pkg/workloads/cortex/serve/requirements.txt
index 6ae6933fef..466e22bfa3 100644
--- a/pkg/workloads/cortex/serve/requirements.txt
+++ b/pkg/workloads/cortex/serve/requirements.txt
@@ -4,6 +4,7 @@ dill==0.3.1.1
 fastapi==0.54.1
 msgpack==1.0.0
 numpy==1.18.4
+python-multipart==0.0.5
 pyyaml==5.3.1
 requests==2.23.0
 uvicorn==0.11.5
diff --git a/pkg/workloads/cortex/serve/serve.py b/pkg/workloads/cortex/serve/serve.py
index 6929da275c..ce5dc84641 100644
--- a/pkg/workloads/cortex/serve/serve.py
+++ b/pkg/workloads/cortex/serve/serve.py
@@ -15,6 +15,7 @@ import sys
 import os
 import argparse
+import inspect
 import time
 import json
 import msgpack
@@ -154,11 +155,31 @@ async def register_request(request: Request, call_next):
     return response
 
 
-def predict(request: Any = Body(..., media_type="application/json")):
+@app.middleware("http")
+async def parse_payload(request: Request, call_next):
+    if "payload" not in local_cache["predict_fn_args"]:
+        return await call_next(request)
+
+    content_type = request.headers.get("content-type", "").lower()
+
+    if content_type.startswith("multipart/form") or content_type.startswith(
+        "application/x-www-form-urlencoded"
+    ):
+        request.state.payload = await request.form()
+    elif content_type.startswith("application/json"):
+        request.state.payload = await request.json()
+    else:
+        request.state.payload = await request.body()
+
+    return await call_next(request)
+
+
+def predict(request: Request):
     api = local_cache["api"]
     predictor_impl = local_cache["predictor_impl"]
+    args = build_predict_args(request)
 
-    prediction = predictor_impl.predict(request)
+    prediction = predictor_impl.predict(**args)
 
     if isinstance(prediction, bytes):
         response = Response(content=prediction, media_type="application/octet-stream")
@@ -194,6 +215,19 @@ def predict(request: Any = Body(..., media_type="application/json")):
     return response
 
 
+def build_predict_args(request: Request):
+    args = {}
+
+    if "payload" in local_cache["predict_fn_args"]:
+        args["payload"] = request.state.payload
+    if "headers" in local_cache["predict_fn_args"]:
+        args["headers"] = request.headers
+    if "query_params" in local_cache["predict_fn_args"]:
+        args["query_params"] = request.query_params
+
+    return args
+
+
 def get_summary():
     response = {"message": API_SUMMARY_MESSAGE}
 
@@ -230,6 +264,7 @@ def start():
         storage = LocalStorage(os.getenv("CORTEX_CACHE_DIR"))
     else:
         storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])
+
     try:
         raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
         api = API(provider=provider, storage=storage, cache_dir=cache_dir, **raw_api_spec)
@@ -243,6 +278,7 @@ def start():
         local_cache["provider"] = provider
         local_cache["client"] = client
         local_cache["predictor_impl"] = predictor_impl
+        local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
         predict_route = "/"
         if provider != "local":
             predict_route = "/predict"
@@ -250,6 +286,7 @@ def start():
     except:
         cx_logger().exception("failed to start api")
         sys.exit(1)
+
     if (
         provider != "local"
         and api.monitoring is not None
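
Taken together, the `parse_payload` middleware and `build_predict_args` let a single API accept several request encodings. A client-side sketch using `requests` (the URL, file path, and field names are placeholders):

```python
import requests

url = "http://localhost:8888/predict"  # placeholder endpoint

# Content-Type: application/json -> predict() receives the parsed JSON body
requests.post(url, json={"text": "hello world"})

# multipart/form-data -> predict() receives starlette FormData;
# "image" arrives as an UploadFile, "threshold" as a str
with open("img.png", "rb") as f:
    requests.post(url, files={"image": f}, data={"threshold": "0.5"})

# any other Content-Type -> predict() receives the raw bytes of the body
requests.post(
    url, data=b"\x00\x01\x02", headers={"Content-Type": "application/octet-stream"}
)
```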