Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JsonInput micro-batching support #860

Merged
merged 8 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bentoml/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from bentoml.adapters.dataframe_input import DataframeInput
from bentoml.adapters.tensorflow_tensor_input import TfTensorInput
from bentoml.adapters.json_input import JsonInput
from bentoml.adapters.legacy_json_input import LegacyJsonInput
from bentoml.adapters.image_input import ImageInput
from bentoml.adapters.multi_image_input import MultiImageInput
from bentoml.adapters.legacy_image_input import LegacyImageInput
Expand Down Expand Up @@ -48,6 +49,7 @@
"TfTensorInput",
'TfTensorOutput',
"JsonInput",
"LegacyJsonInput",
'JsonSerializableOutput',
"ImageInput",
"MultiImageInput",
Expand Down
40 changes: 29 additions & 11 deletions bentoml/adapters/json_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,28 @@


class JsonInput(BaseInputAdapter):
"""JsonInput parses REST API request or CLI command into parsed_json(a
dict in python) and pass down to user defined API function

"""JsonInput parses REST API requests or CLI commands into parsed_jsons (a list of
json-serializable objects in python) and passes them down to the user-defined API function

****
How to upgrade from LegacyJsonInput(JsonInput before 0.8.3)

To enable micro batching for API with json inputs, custom bento service should use
JsonInput and modify the handler method like this:
```
@bentoml.api(input=LegacyJsonInput())
def predict(self, parsed_json):
result = do_something_to_json(parsed_json)
return result
```
--->
```
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls help to refine instructions here @parano

@bentoml.api(input=JsonInput())
def predict(self, parsed_jsons):
results = do_something_to_list_of_json(parsed_jsons)
return results
```
For clients, the request format is the same as with LegacyJsonInput: each request contains a single json object.
"""

BATCH_MODE_SUPPORTED = True
Expand All @@ -37,16 +56,15 @@ def __init__(self, is_batch_input=False, **base_kwargs):
super(JsonInput, self).__init__(is_batch_input=is_batch_input, **base_kwargs)

def handle_request(self, request: flask.Request, func):
if request.content_type == "application/json":
parsed_json = json.loads(request.get_data(as_text=True))
else:
if request.content_type != "application/json":
raise BadInput(
"Request content-type must be 'application/json' for this "
"BentoService API"
)

result = func(parsed_json)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least it should be func([parsed_json])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be a regression right? Should we consider something like the LegacyImageHandler?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LegacyJsonInput?

return self.output_adapter.to_response(result, request)
resps = self.handle_batch_request(
[SimpleRequest.from_flask_request(request)], func
)
return resps[0].to_flask_response()

def handle_batch_request(
self, requests: Iterable[SimpleRequest], func
Expand Down Expand Up @@ -90,7 +108,7 @@ def handle_cli(self, args, func):
content = parsed_args.input

input_json = json.loads(content)
result = func(input_json)
result = func([input_json])[0]
return self.output_adapter.to_cli(result, unknown_args)

def handle_aws_lambda_event(self, event, func):
Expand All @@ -102,5 +120,5 @@ def handle_aws_lambda_event(self, event, func):
"BentoService API lambda endpoint"
)

result = func(parsed_json)
result = func([parsed_json])[0]
return self.output_adapter.to_aws_lambda_event(result, event)
82 changes: 82 additions & 0 deletions bentoml/adapters/legacy_json_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2019 Atalaya Tech, Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import argparse
from typing import Iterable

import flask

from bentoml.exceptions import BadInput
from bentoml.marshal.utils import SimpleResponse, SimpleRequest
from bentoml.adapters.base_input import BaseInputAdapter


class LegacyJsonInput(BaseInputAdapter):
    """LegacyJsonInput parses a REST API request or CLI command into parsed_json
    (a single json-serializable object in python, typically a dict) and passes it
    down to the user-defined API function.

    This adapter preserves the pre-0.8.3 JsonInput behavior: the API function
    receives one parsed json object per call, so micro batching is not supported
    (see handle_batch_request). Use JsonInput to enable micro batching.
    """

    # Micro batching is intentionally unsupported for this legacy adapter.
    BATCH_MODE_SUPPORTED = False

    def __init__(self, is_batch_input=False, **base_kwargs):
        super(LegacyJsonInput, self).__init__(
            is_batch_input=is_batch_input, **base_kwargs
        )

    def handle_request(self, request: flask.Request, func):
        """Parse the json body of a flask request and respond with func's result.

        Raises:
            BadInput: when the Content-Type header is missing or is not
                'application/json'.
        """
        # content_type is None when the header is absent; guard before calling
        # .lower() so clients get BadInput instead of an AttributeError.
        if (
            not request.content_type
            or request.content_type.lower() != "application/json"
        ):
            # NOTE: fixed copy-paste error — this is the REST handler, so the
            # message must not mention the lambda endpoint.
            raise BadInput(
                "Request content-type must be 'application/json' for this "
                "BentoService API"
            )

        parsed_json = json.loads(request.get_data(as_text=True))
        result = func(parsed_json)
        return self.output_adapter.to_response(result, request)

    def handle_batch_request(
        self, requests: Iterable[SimpleRequest], func
    ) -> Iterable[SimpleResponse]:
        # Legacy API functions take one json object, not a list, so a batch of
        # requests cannot be dispatched here.
        raise NotImplementedError("Use JsonInput instead to enable micro batching")

    def handle_cli(self, args, func):
        """Handle CLI invocation; --input is a json file path or a raw json string."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--input", required=True)
        parsed_args, unknown_args = parser.parse_known_args(args)

        if os.path.isfile(parsed_args.input):
            with open(parsed_args.input, "r") as content_file:
                content = content_file.read()
        else:
            content = parsed_args.input

        input_json = json.loads(content)
        result = func(input_json)
        return self.output_adapter.to_cli(result, unknown_args)

    def handle_aws_lambda_event(self, event, func):
        """Handle an AWS Lambda proxy event whose body holds a single json document.

        Raises:
            BadInput: when the Content-Type header is missing or is not
                'application/json'.
        """
        # Use .get() so a missing "headers" dict or "Content-Type" key yields
        # BadInput rather than an unhandled KeyError.
        if event.get("headers", {}).get("Content-Type") != "application/json":
            raise BadInput(
                "Request content-type must be 'application/json' for this "
                "BentoService API lambda endpoint"
            )

        parsed_json = json.loads(event["body"])
        result = func(parsed_json)
        return self.output_adapter.to_aws_lambda_event(result, event)
10 changes: 8 additions & 2 deletions bentoml/marshal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ class SimpleRequest(NamedTuple):
@property
@lru_cache()
def formated_headers(self):
return {hk.decode().lower(): hv.decode() for hk, hv in self.headers or tuple()}
return {
hk.decode("ascii").lower(): hv.decode("ascii")
for hk, hv in self.headers or tuple()
}

@classmethod
def from_flask_request(cls, request):
return cls(tuple(request.headers), request.get_data())
return cls(
tuple((k.encode("ascii"), v.encode("ascii")) for k, v in request.headers),
request.get_data(),
)


class SimpleResponse(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
"pandas",
"pylint>=2.5.2",
"pytest-cov>=2.7.1",
"pytest>=4.6.0",
"pytest>=5.4.0",
"pytest-asyncio",
"scikit-learn",
"protobuf==3.6.0",
] + aws_sam_cli
Expand Down
5 changes: 5 additions & 0 deletions tests/bento_service_examples/example_bento_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
ImageInput,
LegacyImageInput,
JsonInput,
LegacyJsonInput,
# FastaiImageInput,
)
from bentoml.handlers import DataframeHandler # deprecated
Expand Down Expand Up @@ -47,6 +48,10 @@ def predict_images(self, original, compared):
def predict_json(self, input_data):
return self.artifacts.model.predict_json(input_data)

@bentoml.api(input=LegacyJsonInput())
def predict_legacy_json(self, input_data):
return self.artifacts.model.predict_legacy_json(input_data)

# Disabling fastai related tests to fix travis build
# @bentoml.api(input=FastaiImageInput())
# def predict_fastai_image(self, input_data):
Expand Down
24 changes: 22 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@
from tests.bento_service_examples.example_bento_service import ExampleBentoService


def pytest_configure():
    """Expose shared dataframe-json orient constants on the pytest namespace."""
    # All json orients exercised by the dataframe handler tests.
    # 'table' is not covered yet  # TODO(bojiang)
    pytest.DF_ORIENTS = {'split', 'records', 'index', 'columns', 'values'}
    # Orients that the handler can auto-detect without an explicit hint.
    pytest.DF_AUTO_ORIENTS = {'records', 'columns'}


@pytest.fixture()
def img_file(tmpdir):
img_file_ = tmpdir.join("test_img.jpg")
Expand All @@ -31,8 +47,12 @@ def predict_image(self, input_datas):
assert input_data is not None
return [input_data.shape for input_data in input_datas]

def predict_json(self, input_data):
assert input_data is not None
def predict_json(self, input_jsons):
assert input_jsons
return [{"ok": True}] * len(input_jsons)

def predict_legacy_json(self, input_json):
assert input_json is not None
return {"ok": True}


Expand Down
Empty file added tests/handlers/__init__.py
Empty file.
16 changes: 3 additions & 13 deletions tests/handlers/test_dataframe_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_function(df):
input_adapter = DataframeInput()
csv_data = 'name,game,city\njohn,mario,sf'
request = MagicMock(spec=flask.Request)
request.headers = {'orient': 'records'}
request.headers = (('orient', 'records'),)
request.content_type = 'text/csv'
request.get_data.return_value = csv_data

Expand All @@ -142,16 +142,6 @@ def assert_df_equal(left: pd.DataFrame, right: pd.DataFrame):
)


DF_ORIENTS = {
'split',
'records',
'index',
'columns',
'values',
# 'table', # TODO(bojiang)
}


DF_CASES = (
pd.DataFrame(np.random.rand(1, 3)),
pd.DataFrame(np.random.rand(2, 3)),
Expand All @@ -171,7 +161,7 @@ def df(request):
return request.param


@pytest.fixture(params=DF_ORIENTS)
@pytest.fixture(params=pytest.DF_ORIENTS)
def orient(request):
return request.param

Expand All @@ -181,7 +171,7 @@ def test_batch_read_dataframes_from_mixed_json_n_csv(df):
test_types = []

# test content_type=application/json with various orients
for orient in DF_ORIENTS:
for orient in pytest.DF_ORIENTS:
try:
assert_df_equal(df, pd.read_json(df.to_json(orient=orient)))
except (AssertionError, ValueError):
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_fastai_image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ def predict(self, image):
api = ms.get_service_apis()[0]
test_args = ["--input={}".format(img_file)]
api.handle_cli(test_args)
out, err = capsys.readouterr()
out, _ = capsys.readouterr()
assert out.strip() == '[3, 10, 10]'
14 changes: 7 additions & 7 deletions tests/handlers/test_json_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@


def test_json_handle_cli(capsys, tmpdir):
def test_func(obj):
return obj[0]["name"]
def test_func(objs):
return [obj["name"] for obj in objs]

input_adapter = JsonInput()

json_file = tmpdir.join("test.json")
with open(str(json_file), "w") as f:
f.write('[{"name": "john","game": "mario","city": "sf"}]')
f.write('{"name": "john","game": "mario","city": "sf"}')

test_args = ["--input={}".format(json_file)]
input_adapter.handle_cli(test_args, test_func)
out, err = capsys.readouterr()
out, _ = capsys.readouterr()
assert out.strip().endswith("john")


def test_json_handle_aws_lambda_event():
test_content = '[{"name": "john","game": "mario","city": "sf"}]'
test_content = '{"name": "john","game": "mario","city": "sf"}'

def test_func(obj):
return obj[0]["name"]
def test_func(objs):
return [obj["name"] for obj in objs]

input_adapter = JsonInput()
success_event_obj = {
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_legacy_image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_image_input_cli(capsys, img_file):

test_args = ["--input={}".format(img_file)]
test_image_input.handle_cli(test_args, predict)
out, err = capsys.readouterr()
out, _ = capsys.readouterr()
assert out.strip().endswith("(10, 10, 3)")


Expand All @@ -28,7 +28,7 @@ def test_image_input_aws_lambda_event(img_file):
try:
image_bytes_encoded = base64.encodebytes(content)
except AttributeError:
image_bytes_encoded = base64.encodestring(str(img_file))
image_bytes_encoded = base64.encodebytes(img_file)

aws_lambda_event = {
"body": image_bytes_encoded,
Expand Down
Loading