diff --git a/doc/frameworks/tensorflow/upgrade_from_legacy.rst b/doc/frameworks/tensorflow/upgrade_from_legacy.rst index 84f77c01a2..ea80b65c51 100644 --- a/doc/frameworks/tensorflow/upgrade_from_legacy.rst +++ b/doc/frameworks/tensorflow/upgrade_from_legacy.rst @@ -245,11 +245,10 @@ For example, if you want to use JSON serialization and deserialization: .. code:: python - from sagemaker.predictor import json_deserializer + from sagemaker.deserializers import JSONDeserializer from sagemaker.serializers import JSONSerializer predictor.serializer = JSONSerializer() - predictor.accept = "application/json" - predictor.deserializer = json_deserializer + predictor.deserializer = JSONDeserializer() predictor.predict(data) diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 01d6e3b404..3126435789 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -16,7 +16,8 @@ from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le -from sagemaker.predictor import Predictor, csv_serializer, json_deserializer +from sagemaker.deserializers import JSONDeserializer +from sagemaker.predictor import Predictor, csv_serializer from sagemaker.model import Model from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): endpoint_name, sagemaker_session, serializer=csv_serializer, - deserializer=json_deserializer, + deserializer=JSONDeserializer(), ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index b49b5aeabb..e72dbc94a8 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -187,3 +187,24 @@ def deserialize(self, data, content_type): data.close() raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type)) + + +class JSONDeserializer(BaseDeserializer): + """Deserialize JSON data from an inference endpoint into a Python object.""" + + ACCEPT = "application/json" + + def deserialize(self, data, content_type): + """Deserialize JSON data from an inference endpoint into a Python object. + + Args: + data (botocore.response.StreamingBody): Data to be deserialized. + content_type (str): The MIME type of the data. + + Returns: + object: The JSON-formatted data deserialized into a Python object. + """ + try: + return json.load(codecs.getreader("utf-8")(data)) + finally: + data.close() diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index cc1f9396f2..8a7bb90e19 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -18,6 +18,7 @@ import packaging.version import sagemaker +from sagemaker.deserializers import JSONDeserializer from sagemaker.fw_utils import ( create_image_uri, model_code_key_prefix, @@ -26,7 +27,7 @@ ) from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.mxnet import defaults -from sagemaker.predictor import Predictor, json_deserializer +from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer logger = logging.getLogger("sagemaker") @@ -51,7 +52,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): using the default AWS configuration chain. """ super(MXNetPredictor, self).__init__( - endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer + endpoint_name, sagemaker_session, JSONSerializer(), JSONDeserializer() ) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 40e11cad83..c6012d3722 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -13,13 +13,11 @@ """Placeholder docstring""" from __future__ import print_function, absolute_import -import codecs import csv -import json from six import StringIO import numpy as np -from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV +from sagemaker.content_types import CONTENT_TYPE_CSV from sagemaker.deserializers import BaseDeserializer from sagemaker.model_monitor import DataCaptureConfig from sagemaker.serializers import BaseSerializer @@ -594,29 +592,3 @@ def _row_to_csv(obj): if isinstance(obj, str): return obj return ",".join(obj) - - -class _JsonDeserializer(object): - """Placeholder docstring""" - - def __init__(self): - """Placeholder docstring""" - self.accept = CONTENT_TYPE_JSON - - def __call__(self, stream, content_type): - """Decode a JSON object into the corresponding Python object. - - Args: - stream (stream): The response stream to be deserialized. - content_type (str): The content type of the response. - - Returns: - object: Body of the response deserialized into a JSON object. - """ - try: - return json.load(codecs.getreader("utf-8")(stream)) - finally: - stream.close() - - -json_deserializer = _JsonDeserializer() diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 5215f971cc..4018e3e776 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -17,8 +17,9 @@ import sagemaker from sagemaker.content_types import CONTENT_TYPE_JSON +from sagemaker.deserializers import JSONDeserializer from sagemaker.fw_utils import create_image_uri -from sagemaker.predictor import json_deserializer, Predictor +from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer @@ -32,7 +33,7 @@ def __init__( endpoint_name, sagemaker_session=None, serializer=JSONSerializer(), - deserializer=json_deserializer, + deserializer=JSONDeserializer(), content_type=None, model_name=None, model_version=None, diff --git a/tests/integ/test_byo_estimator.py b/tests/integ/test_byo_estimator.py index 9cc4902305..23e11a2dde 100644 --- a/tests/integ/test_byo_estimator.py +++ b/tests/integ/test_byo_estimator.py @@ -86,7 +86,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name) predictor.serializer = fm_serializer predictor.content_type = "application/json" - predictor.deserializer = sagemaker.predictor.json_deserializer + predictor.deserializer = sagemaker.deserializers.JSONDeserializer() result = predictor.predict(training_set[0][:10]) @@ -132,7 +132,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name) predictor.serializer = fm_serializer predictor.content_type = "application/json" - predictor.deserializer = sagemaker.predictor.json_deserializer + predictor.deserializer = sagemaker.deserializers.JSONDeserializer() result = predictor.predict(training_set[0][:10]) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 170773578f..002e696d71 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -188,7 +188,7 @@ def test_predict_jsons_json_content_type(tfs_predictor): tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, + deserializer=sagemaker.deserializers.JSONDeserializer(), content_type="application/json", accept="application/json", ) @@ -205,7 +205,7 @@ def test_predict_jsons(tfs_predictor): tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, + deserializer=sagemaker.deserializers.JSONDeserializer(), content_type="application/jsons", accept="application/jsons", ) @@ -222,7 +222,7 @@ def test_predict_jsonlines(tfs_predictor): tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, + deserializer=sagemaker.deserializers.JSONDeserializer(), content_type="application/jsonlines", accept="application/jsonlines", ) diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py index 9fa3e9b482..82111a30c7 100644 --- a/tests/integ/test_tuner.py +++ b/tests/integ/test_tuner.py @@ -25,9 +25,9 @@ from sagemaker.amazon.amazon_estimator import get_image_uri from sagemaker.amazon.common import read_records from sagemaker.chainer import Chainer +from sagemaker.deserializers import JSONDeserializer from sagemaker.estimator import Estimator from sagemaker.mxnet.estimator import MXNet -from sagemaker.predictor import json_deserializer from sagemaker.pytorch import PyTorch from sagemaker.tensorflow import TensorFlow from sagemaker.tuner import ( @@ -891,7 +891,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type): predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job) predictor.serializer = _fm_serializer predictor.content_type = "application/json" - predictor.deserializer = json_deserializer + predictor.deserializer = JSONDeserializer() result = predictor.predict(datasets.one_p_mnist()[0][:10]) diff --git a/tests/integ/test_tuner_multi_algo.py b/tests/integ/test_tuner_multi_algo.py index f382a17ef0..640237e20c 100644 --- a/tests/integ/test_tuner_multi_algo.py +++ b/tests/integ/test_tuner_multi_algo.py @@ -21,8 +21,8 @@ from sagemaker.amazon.amazon_estimator import get_image_uri from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.content_types import CONTENT_TYPE_JSON +from sagemaker.deserializers import JSONDeserializer from sagemaker.estimator import Estimator -from sagemaker.predictor import json_deserializer from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -219,7 +219,7 @@ def _create_training_inputs(sagemaker_session): def _make_prediction(predictor, data): predictor.serializer = _prediction_data_serializer predictor.content_type = CONTENT_TYPE_JSON - predictor.deserializer = json_deserializer + predictor.deserializer = JSONDeserializer() return predictor.predict(data) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index edd4deb474..44a459ae47 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -23,6 +23,7 @@ CSVDeserializer, StreamDeserializer, NumpyDeserializer, + JSONDeserializer, ) @@ -145,3 +146,28 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer): result = numpy_deserializer.deserialize(stream, "application/x-npy") assert np.array_equal(array, result) + + +@pytest.fixture +def json_deserializer(): + return JSONDeserializer() + + +def test_json_deserializer_array(json_deserializer): + result = json_deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") + + assert result == [1, 2, 3] + + +def test_json_deserializer_2dimensional(json_deserializer): + result = json_deserializer.deserialize( + io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json" + ) + + assert result == [[1, 2, 3], [3, 4, 5]] + + +def test_json_deserializer_invalid_data(json_deserializer): + with pytest.raises(ValueError) as error: + json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json") + assert "column" in str(error) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 2cb3b22781..955ecf7fc2 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import io import json import os @@ -21,10 +20,7 @@ from mock import Mock, call, patch from sagemaker.predictor import Predictor -from sagemaker.predictor import ( - json_deserializer, - csv_serializer, -) +from sagemaker.predictor import csv_serializer from sagemaker.serializers import JSONSerializer from tests.unit import DATA_DIR @@ -97,24 +93,6 @@ def test_csv_serializer_csv_reader(): assert result == validation_data -def test_json_deserializer_array(): - result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json") - - assert result == [1, 2, 3] - - -def test_json_deserializer_2dimensional(): - result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json") - - assert result == [[1, 2, 3], [3, 4, 5]] - - -def test_json_deserializer_invalid_data(): - with pytest.raises(ValueError) as error: - json_deserializer(io.BytesIO(b"[[1]"), "application/json") - assert "column" in str(error) - - # testing 'predict' invocations