Support query params, headers, and non-json payloads (bytes and forms) #1062

Merged: 3 commits, merged May 18, 2020
24 changes: 18 additions & 6 deletions docs/deployments/predictors.md
@@ -55,11 +55,13 @@ class PythonPredictor:
"""
pass

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference, and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).

Returns:
Prediction or a batch of predictions.
@@ -69,6 +71,8 @@ class PythonPredictor:

For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as where to download the model and initialization files from, or any configurable model parameters. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.

The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.
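
For illustration, a predictor could branch on the payload type it receives. This is a hypothetical sketch (the class shape matches the docs above, but the return values are illustrative only):

```python
class PythonPredictor:
    def __init__(self, config):
        self.config = config

    def predict(self, payload, query_params, headers):
        # application/json bodies arrive already parsed (dict, list, etc.)
        if isinstance(payload, (dict, list)):
            return {"format": "json", "echo": payload}
        # any other Content-Type arrives as the raw request bytes
        if isinstance(payload, bytes):
            return {"format": "bytes", "size": len(payload)}
        # form submissions arrive as starlette.datastructures.FormData
        return {"format": "form"}
```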

### Examples

<!-- CORTEX_VERSION_MINOR -->
@@ -173,11 +177,13 @@ class TensorFlowPredictor:
self.client = tensorflow_client
# Additional initialization may be done here

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).

Returns:
Prediction or a batch of predictions.
@@ -190,6 +196,8 @@ Cortex provides a `tensorflow_client` to your Predictor's constructor. `tensorfl

For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.

The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.

### Examples

<!-- CORTEX_VERSION_MINOR -->
@@ -249,11 +257,13 @@ class ONNXPredictor:
self.client = onnx_client
# Additional initialization may be done here

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).

Returns:
Prediction or a batch of predictions.
@@ -266,6 +276,8 @@ Cortex provides an `onnx_client` to your Predictor's constructor. `onnx_client`

For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.

The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form-data` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads; see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will be the raw `bytes` of the request body.

### Examples

<!-- CORTEX_VERSION_MINOR -->
87 changes: 62 additions & 25 deletions pkg/workloads/cortex/lib/type/predictor.py
@@ -57,10 +57,12 @@ def initialize_client(self, model_dir=None, tf_serving_host=None, tf_serving_por
    def initialize_impl(self, project_dir, client=None):
        class_impl = self.class_impl(project_dir)
        try:
            if self.type == "onnx":
                return class_impl(onnx_client=client, config=self.config)
            elif self.type == "tensorflow":
                return class_impl(tensorflow_client=client, config=self.config)
            else:
                return class_impl(config=self.config)
        except Exception as e:
            raise UserRuntimeException(self.path, "__init__", str(e)) from e
        finally:
@@ -128,55 +130,90 @@ def _load_module(self, module_name, impl_path):

PYTHON_CLASS_VALIDATION = {
    "required": [
        {"name": "__init__", "required_args": ["self", "config"]},
        {
            "name": "predict",
            "required_args": ["self"],
            "optional_args": ["payload", "query_params", "headers"],
        },
    ]
}

TENSORFLOW_CLASS_VALIDATION = {
    "required": [
        {"name": "__init__", "required_args": ["self", "tensorflow_client", "config"]},
        {
            "name": "predict",
            "required_args": ["self"],
            "optional_args": ["payload", "query_params", "headers"],
        },
    ]
}

ONNX_CLASS_VALIDATION = {
    "required": [
        {"name": "__init__", "required_args": ["self", "onnx_client", "config"]},
        {
            "name": "predict",
            "required_args": ["self"],
            "optional_args": ["payload", "query_params", "headers"],
        },
    ]
}


def _validate_impl(impl, impl_req):
    for optional_func_signature in impl_req.get("optional", []):
        _validate_optional_fn_args(impl, optional_func_signature)

    for required_func_signature in impl_req.get("required", []):
        _validate_required_fn_args(impl, required_func_signature)


def _validate_optional_fn_args(impl, func_signature):
    if getattr(impl, func_signature["name"], None):
        _validate_required_fn_args(impl, func_signature)


def _validate_required_fn_args(impl, func_signature):
    fn = getattr(impl, func_signature["name"], None)
    if not fn:
        raise UserException(f'required function "{func_signature["name"]}" is not defined')

    if not callable(fn):
        raise UserException(f'"{func_signature["name"]}" is defined, but is not a function')

    argspec = inspect.getfullargspec(fn)

    required_args = func_signature.get("required_args", [])
    optional_args = func_signature.get("optional_args", [])
    fn_str = f'{func_signature["name"]}({", ".join(argspec.args)})'

    for arg_name in required_args:
        if arg_name not in argspec.args:
            raise UserException(
                f'invalid signature for function "{fn_str}": "{arg_name}" is a required argument, but was not provided'
            )

        if arg_name == "self":
            if argspec.args[0] != "self":
                raise UserException(
                    f'invalid signature for function "{fn_str}": "self" must be the first argument'
                )

    seen_args = []
    for arg_name in argspec.args:
        if arg_name not in required_args and arg_name not in optional_args:
            raise UserException(
                f'invalid signature for function "{fn_str}": "{arg_name}" is not a supported argument'
            )

        if arg_name in seen_args:
            raise UserException(
                f'invalid signature for function "{fn_str}": "{arg_name}" is duplicated'
            )

        seen_args.append(arg_name)
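
The required/optional scheme above can be exercised in isolation. A minimal sketch, using a plain `ValueError` in place of Cortex's `UserException` and a hypothetical `validate_signature` helper:

```python
import inspect

def validate_signature(fn, required_args, optional_args):
    # every required arg must be declared; every declared arg must be known
    args = inspect.getfullargspec(fn).args
    for name in required_args:
        if name not in args:
            raise ValueError(f'"{name}" is a required argument, but was not provided')
    for name in args:
        if name not in required_args and name not in optional_args:
            raise ValueError(f'"{name}" is not a supported argument')

def predict(self, payload, query_params):
    pass

# passes: "headers" may be omitted because it is optional
validate_signature(predict, ["self"], ["payload", "query_params", "headers"])
```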


tf_expected_dir_structure = """tensorflow model directories must have the following structure:
1 change: 1 addition & 0 deletions pkg/workloads/cortex/serve/requirements.txt
@@ -4,6 +4,7 @@ dill==0.3.1.1
fastapi==0.54.1
msgpack==1.0.0
numpy==1.18.4
python-multipart==0.0.5
pyyaml==5.3.1
requests==2.23.0
uvicorn==0.11.5
41 changes: 39 additions & 2 deletions pkg/workloads/cortex/serve/serve.py
@@ -15,6 +15,7 @@
import sys
import os
import argparse
import inspect
import time
import json
import msgpack
@@ -154,11 +155,31 @@ async def register_request(request: Request, call_next):
return response


@app.middleware("http")
async def parse_payload(request: Request, call_next):
    if "payload" not in local_cache["predict_fn_args"]:
        return await call_next(request)

    content_type = request.headers.get("content-type", "").lower()

    if content_type.startswith("multipart/form") or content_type.startswith(
        "application/x-www-form-urlencoded"
    ):
        request.state.payload = await request.form()
    elif content_type.startswith("application/json"):
        request.state.payload = await request.json()
    else:
        request.state.payload = await request.body()

    return await call_next(request)
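
The middleware's dispatch on `Content-Type` can be summarized as a pure function. This is a sketch only: the real middleware uses Starlette's `request.json()`, `request.form()`, and `request.body()` coroutines, and form handling is omitted here:

```python
import json

def parse_body(content_type: str, body: bytes):
    ct = content_type.lower()
    if ct.startswith("application/json"):
        return json.loads(body)  # parsed Python object
    # (multipart/form and x-www-form-urlencoded would return FormData here)
    return body  # raw bytes for every other Content-Type

assert parse_body("application/json", b'{"x": 1}') == {"x": 1}
assert parse_body("application/octet-stream", b"\x00") == b"\x00"
```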


def predict(request: Request):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]
    args = build_predict_args(request)

    prediction = predictor_impl.predict(**args)

    if isinstance(prediction, bytes):
        response = Response(content=prediction, media_type="application/octet-stream")
@@ -194,6 +215,19 @@ def predict(request: Any = Body(..., media_type="application/json")):
return response


def build_predict_args(request: Request):
    args = {}

    if "payload" in local_cache["predict_fn_args"]:
        args["payload"] = request.state.payload
    if "headers" in local_cache["predict_fn_args"]:
        args["headers"] = request.headers
    if "query_params" in local_cache["predict_fn_args"]:
        args["query_params"] = request.query_params

    return args
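
The underlying pattern — inspect the user's `predict` once, then pass only the arguments it declares — can be shown in isolation (the helper name `build_args` and the sample values are hypothetical):

```python
import inspect

def build_args(fn, available):
    # keep only the kwargs that fn actually declares
    wanted = inspect.getfullargspec(fn).args
    return {name: value for name, value in available.items() if name in wanted}

def predict(payload, headers):
    return payload

available = {"payload": {"x": 1}, "query_params": {}, "headers": {"h": "v"}}
args = build_args(predict, available)
# query_params is dropped because predict() does not declare it
predict(**args)
```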


def get_summary():
response = {"message": API_SUMMARY_MESSAGE}

@@ -230,6 +264,7 @@ def start():
storage = LocalStorage(os.getenv("CORTEX_CACHE_DIR"))
else:
storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])

try:
raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
api = API(provider=provider, storage=storage, cache_dir=cache_dir, **raw_api_spec)
@@ -243,13 +278,15 @@
local_cache["provider"] = provider
local_cache["client"] = client
local_cache["predictor_impl"] = predictor_impl
local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
predict_route = "/"
if provider != "local":
predict_route = "/predict"
local_cache["predict_route"] = predict_route
except:
cx_logger().exception("failed to start api")
sys.exit(1)

if (
provider != "local"
and api.monitoring is not None