-
Notifications
You must be signed in to change notification settings - Fork 811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix JsonInput micro-batching support #860
Changes from all commits
0091f75
3340eef
f7c6390
c7314f4
07927b8
be8f14e
e20544e
ceb0d3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,9 +26,28 @@ | |
|
||
|
||
class JsonInput(BaseInputAdapter): | ||
"""JsonInput parses REST API request or CLI command into parsed_json(a | ||
dict in python) and pass down to user defined API function | ||
|
||
"""JsonInput parses REST API requests or CLI commands into parsed_jsons (a list of | ||
JSON-serializable objects in Python) and passes it down to the user-defined API function | ||
|
||
**** | ||
How to upgrade from LegacyJsonInput (JsonInput before 0.8.3) | ||
|
||
To enable micro-batching for an API with JSON inputs, a custom BentoService should use | ||
JsonInput and change the handler method to accept and return lists, like this: | ||
``` | ||
@bentoml.api(input=LegacyJsonInput()) | ||
def predict(self, parsed_json): | ||
result = do_something_to_json(parsed_json) | ||
return result | ||
``` | ||
---> | ||
``` | ||
@bentoml.api(input=JsonInput()) | ||
def predict(self, parsed_jsons): | ||
results = do_something_to_list_of_json(parsed_jsons) | ||
return results | ||
``` | ||
For clients, the request format is the same as with LegacyJsonInput: each request contains a single JSON document. | ||
""" | ||
|
||
BATCH_MODE_SUPPORTED = True | ||
|
@@ -37,16 +56,15 @@ def __init__(self, is_batch_input=False, **base_kwargs): | |
super(JsonInput, self).__init__(is_batch_input=is_batch_input, **base_kwargs) | ||
|
||
def handle_request(self, request: flask.Request, func): | ||
if request.content_type == "application/json": | ||
parsed_json = json.loads(request.get_data(as_text=True)) | ||
else: | ||
if request.content_type != "application/json": | ||
raise BadInput( | ||
"Request content-type must be 'application/json' for this " | ||
"BentoService API" | ||
) | ||
|
||
result = func(parsed_json) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. At least it should be func([parsed_json]) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will be a regression right? Should we consider something like the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. LegacyJsonInput? |
||
return self.output_adapter.to_response(result, request) | ||
resps = self.handle_batch_request( | ||
[SimpleRequest.from_flask_request(request)], func | ||
) | ||
return resps[0].to_flask_response() | ||
|
||
def handle_batch_request( | ||
self, requests: Iterable[SimpleRequest], func | ||
|
@@ -90,7 +108,7 @@ def handle_cli(self, args, func): | |
content = parsed_args.input | ||
|
||
input_json = json.loads(content) | ||
result = func(input_json) | ||
result = func([input_json])[0] | ||
return self.output_adapter.to_cli(result, unknown_args) | ||
|
||
def handle_aws_lambda_event(self, event, func): | ||
|
@@ -102,5 +120,5 @@ def handle_aws_lambda_event(self, event, func): | |
"BentoService API lambda endpoint" | ||
) | ||
|
||
result = func(parsed_json) | ||
result = func([parsed_json])[0] | ||
return self.output_adapter.to_aws_lambda_event(result, event) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2019 Atalaya Tech, Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import json | ||
import argparse | ||
from typing import Iterable | ||
|
||
import flask | ||
|
||
from bentoml.exceptions import BadInput | ||
from bentoml.marshal.utils import SimpleResponse, SimpleRequest | ||
from bentoml.adapters.base_input import BaseInputAdapter | ||
|
||
|
||
class LegacyJsonInput(BaseInputAdapter): | ||
"""LegacyJsonInput parses REST API request or CLI command into parsed_json(a | ||
dict in python) and pass down to user defined API function | ||
|
||
""" | ||
|
||
BATCH_MODE_SUPPORTED = False | ||
|
||
def __init__(self, is_batch_input=False, **base_kwargs): | ||
super(LegacyJsonInput, self).__init__( | ||
is_batch_input=is_batch_input, **base_kwargs | ||
) | ||
|
||
def handle_request(self, request: flask.Request, func): | ||
if request.content_type.lower() == "application/json": | ||
parsed_json = json.loads(request.get_data(as_text=True)) | ||
else: | ||
raise BadInput( | ||
"Request content-type must be 'application/json' for this " | ||
"BentoService API lambda endpoint" | ||
) | ||
|
||
result = func(parsed_json) | ||
return self.output_adapter.to_response(result, request) | ||
|
||
def handle_batch_request( | ||
self, requests: Iterable[SimpleRequest], func | ||
) -> Iterable[SimpleResponse]: | ||
raise NotImplementedError("Use JsonInput instead to enable micro batching") | ||
|
||
def handle_cli(self, args, func): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input", required=True) | ||
parsed_args, unknown_args = parser.parse_known_args(args) | ||
|
||
if os.path.isfile(parsed_args.input): | ||
with open(parsed_args.input, "r") as content_file: | ||
content = content_file.read() | ||
else: | ||
content = parsed_args.input | ||
|
||
input_json = json.loads(content) | ||
result = func(input_json) | ||
return self.output_adapter.to_cli(result, unknown_args) | ||
|
||
def handle_aws_lambda_event(self, event, func): | ||
if event["headers"]["Content-Type"] == "application/json": | ||
parsed_json = json.loads(event["body"]) | ||
else: | ||
raise BadInput( | ||
"Request content-type must be 'application/json' for this " | ||
"BentoService API lambda endpoint" | ||
) | ||
|
||
result = func(parsed_json) | ||
return self.output_adapter.to_aws_lambda_event(result, event) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pls help to refine instructions here @parano