Support multiple types of responses in predict() #915

Merged
47 changes: 47 additions & 0 deletions docs/deployments/predictors.md
@@ -10,6 +10,8 @@ Which Predictor you use depends on how your model is exported:
* [ONNX Predictor](#onnx-predictor) if your model is exported in the ONNX format
* [Python Predictor](#python-predictor) for all other cases

The response type of the predictor can vary depending on your requirements; see [API responses](#api-responses) below.

## Project files

Cortex makes all files in the project directory (i.e. the directory which contains `cortex.yaml`) available for use in your Predictor implementation. Python bytecode files (`*.pyc`, `*.pyo`, `*.pyd`), files or folders that start with `.`, and the API configuration file (e.g. `cortex.yaml`) are excluded. You may also add a `.cortexignore` file at the root of the project directory, which follows the same syntax and behavior as a [.gitignore file](https://git-scm.com/docs/gitignore).
@@ -308,3 +310,48 @@ requests==2.22.0
The pre-installed system packages are listed in [images/onnx-serve/Dockerfile](https://github.com/cortexlabs/cortex/tree/master/images/onnx-serve/Dockerfile) (for CPU) or [images/onnx-serve-gpu/Dockerfile](https://github.com/cortexlabs/cortex/tree/master/images/onnx-serve-gpu/Dockerfile) (for GPU).

If your application requires additional dependencies, you can install additional [Python packages](python-packages.md) and [system packages](system-packages.md).

## API responses

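The value returned by your `predict()` function may be one of the following:

1. A JSON-serializable object (*lists*, *dictionaries*, *numbers*, etc.), returned with the `application/json` media type

2. A `str` object (e.g. `"class 1"`), returned with the `text/plain` media type

3. A `bytes` object (e.g. `bytes(4)` or `pickle.dumps(obj)`), returned with the `application/octet-stream` media type

4. An instance of [starlette.responses.Response](https://www.starlette.io/responses/#response), returned as is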

Here are some examples:

```python
def predict(self, payload):
    # json-serializable object
    response = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    return response
```

```python
def predict(self, payload):
    # string object
    response = "class 1"
    return response
```

```python
import pickle

import numpy as np

def predict(self, payload):
    # bytes object (here, a pickled NumPy array)
    array = np.random.randn(3, 3)
    response = pickle.dumps(array)
    return response
```

```python
import starlette.responses

def predict(self, payload):
    # starlette.responses.Response, returned to the client unchanged
    data = "class 1"
    response = starlette.responses.Response(
        content=data, media_type="text/plain")
    return response
```
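
On the client side, the `Content-Type` header of the response indicates which of the cases above applies. The following is a minimal sketch of decoding a response body, assuming the media types set in `serve.py` in this PR (`application/json`, `text/plain`, `application/octet-stream`). The `decode_prediction` helper is hypothetical and not part of Cortex.

```python
import json


def decode_prediction(content_type: str, body: bytes):
    """Decode a raw HTTP response body based on its Content-Type header.

    Hypothetical client-side helper for illustration; not part of Cortex.
    """
    # Drop parameters such as "; charset=utf-8" and normalize the case
    media_type = content_type.split(";", 1)[0].strip().lower()

    if media_type == "application/json":
        # predict() returned a JSON-serializable object
        return json.loads(body.decode("utf-8"))
    if media_type == "text/plain":
        # predict() returned a str
        return body.decode("utf-8")
    # application/octet-stream or a custom media type: return the raw bytes
    return body
```

If the bytes were produced with `pickle.dumps`, they can be restored with `pickle.loads`, but only do this for an API you trust, since unpickling can execute arbitrary code.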
30 changes: 19 additions & 11 deletions pkg/workloads/cortex/serve/serve.py
@@ -154,17 +154,25 @@ def predict(request: Any = Body(..., media_type="application/json"), debug=False
     debug_obj("payload", request, debug)
     prediction = predictor_impl.predict(request)

-    try:
-        json_string = json.dumps(prediction)
-    except Exception as e:
-        raise UserRuntimeException(
-            f"the return value of predict() or one of its nested values is not JSON serializable",
-            str(e),
-        ) from e
-
-    debug_obj("prediction", json_string, debug)
-
-    response = Response(content=json_string, media_type="application/json")
+    if isinstance(prediction, bytes):
+        debug_obj("prediction", prediction, debug)
+        response = Response(content=prediction, media_type="application/octet-stream")
+    elif isinstance(prediction, str):
+        debug_obj("prediction", prediction, debug)
+        response = Response(content=prediction, media_type="text/plain")
+    elif isinstance(prediction, Response):
+        debug_obj("prediction", prediction, debug)
+        response = prediction
+    else:
+        try:
+            json_string = json.dumps(prediction)
+        except Exception as e:
+            raise UserRuntimeException(
+                "the return value of predict() or one of its nested values is not JSON serializable; please return a JSON serializeable object, a bytes object, a string object, or a starlette.response.Response object",
+                str(e),
+            ) from e
+        debug_obj("prediction", json_string, debug)
+        response = Response(content=json_string, media_type="application/json")

     if api.tracker is not None:
         try:
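
The type dispatch introduced in this hunk can be tried in isolation. The following is a simplified sketch, not the code from `serve.py`: the `Response` dataclass is a hypothetical stand-in for `starlette.responses.Response` so the example runs with the standard library only, `build_response` is an illustrative name, and a plain `TypeError` replaces `UserRuntimeException`.

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class Response:
    """Hypothetical stand-in for starlette.responses.Response."""
    content: Any
    media_type: str


def build_response(prediction: Any) -> Response:
    """Map the return value of predict() to a response, mirroring the hunk above."""
    if isinstance(prediction, bytes):
        return Response(content=prediction, media_type="application/octet-stream")
    if isinstance(prediction, str):
        # Checked before the JSON branch: json.dumps("class 1") would add quotes
        return Response(content=prediction, media_type="text/plain")
    if isinstance(prediction, Response):
        # Already a response object: pass it through unchanged
        return prediction
    try:
        json_string = json.dumps(prediction)
    except Exception as e:
        # Assumption: TypeError stands in for the project's UserRuntimeException
        raise TypeError("the return value of predict() is not JSON serializable") from e
    return Response(content=json_string, media_type="application/json")
```

The order of the checks matters: a `str` is itself JSON-serializable, so testing for it first is what lets a plain string go out as `text/plain` rather than as a quoted JSON string.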