6 changes: 3 additions & 3 deletions doc/frameworks/tensorflow/upgrade_from_legacy.rst
@@ -245,11 +245,11 @@ For example, if you want to use JSON serialization and deserialization:

.. code:: python

from sagemaker.predictor import json_deserializer, json_serializer
from sagemaker.predictor import json_serializer
from sagemaker.deserializers import JSONDeserializer

predictor.content_type = "application/json"
predictor.serializer = json_serializer
predictor.accept = "application/json"
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()

predictor.predict(data)
5 changes: 3 additions & 2 deletions src/sagemaker/amazon/ipinsights.py
@@ -16,7 +16,8 @@
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
from sagemaker.amazon.validation import ge, le
from sagemaker.predictor import Predictor, csv_serializer, json_deserializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor, csv_serializer
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
endpoint_name,
sagemaker_session,
serializer=csv_serializer,
deserializer=json_deserializer,
deserializer=JSONDeserializer(),
)


21 changes: 21 additions & 0 deletions src/sagemaker/deserializers.py
@@ -154,3 +154,24 @@ def deserialize(self, data, content_type):
data.close()

raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))


class JSONDeserializer(BaseDeserializer):
"""Deserialize data from an inference endpoint into a Python dictionary."""

ACCEPT = "application/json"

def deserialize(self, data, content_type):
"""Deserialize data from an inference endpoint into a Python dictionary.

Args:
data (botocore.response.StreamingBody): Data to be deserialized.
content_type (str): The MIME type of the data.

Returns:
dict: The data deserialized into a Python dictionary.
"""
try:
return json.load(codecs.getreader("utf-8")(data))
finally:
data.close()
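
The new `JSONDeserializer` carries over the behavior of the `_JsonDeserializer` callable removed from `predictor.py`, but exposes it through the `deserialize(data, content_type)` interface of `BaseDeserializer`. A minimal usage sketch, assuming only the standard library and that an `io.BytesIO` buffer can stand in for the `botocore.response.StreamingBody` a live endpoint would return (both expose `read()` and `close()`):

```python
import io

from sagemaker.deserializers import JSONDeserializer

# io.BytesIO stands in for the botocore.response.StreamingBody that a real
# endpoint returns; JSONDeserializer only needs read() and close().
stream = io.BytesIO(b'{"predictions": [4.0, 9.25]}')

deserializer = JSONDeserializer()
result = deserializer.deserialize(stream, "application/json")

print(result)  # {'predictions': [4.0, 9.25]}
# The stream is closed by deserialize(), even if parsing fails.
```

The `ACCEPT = "application/json"` class attribute appears to take over the role of the old instance-level `self.accept = CONTENT_TYPE_JSON`, so callers can still discover the MIME type the deserializer expects.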
5 changes: 3 additions & 2 deletions src/sagemaker/mxnet/model.py
@@ -18,6 +18,7 @@
import packaging.version

import sagemaker
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
@@ -26,7 +27,7 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet import defaults
from sagemaker.predictor import Predictor, json_serializer, json_deserializer
from sagemaker.predictor import Predictor, json_serializer

logger = logging.getLogger("sagemaker")

@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(MXNetPredictor, self).__init__(
endpoint_name, sagemaker_session, json_serializer, json_deserializer
endpoint_name, sagemaker_session, json_serializer, JSONDeserializer()
)


27 changes: 0 additions & 27 deletions src/sagemaker/predictor.py
@@ -13,7 +13,6 @@
"""Placeholder docstring"""
from __future__ import print_function, absolute_import

import codecs
import csv
import json
import six
@@ -672,32 +671,6 @@ def _json_serialize_from_buffer(buff):
return buff.read()


class _JsonDeserializer(object):
"""Placeholder docstring"""

def __init__(self):
"""Placeholder docstring"""
self.accept = CONTENT_TYPE_JSON

def __call__(self, stream, content_type):
"""Decode a JSON object into the corresponding Python object.

Args:
stream (stream): The response stream to be deserialized.
content_type (str): The content type of the response.

Returns:
object: Body of the response deserialized into a JSON object.
"""
try:
return json.load(codecs.getreader("utf-8")(stream))
finally:
stream.close()


json_deserializer = _JsonDeserializer()


class _NPYSerializer(object):
"""Placeholder docstring"""

5 changes: 3 additions & 2 deletions src/sagemaker/tensorflow/model.py
@@ -17,8 +17,9 @@

import sagemaker
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_serializer, json_deserializer, Predictor
from sagemaker.predictor import json_serializer, Predictor


class TensorFlowPredictor(Predictor):
@@ -31,7 +32,7 @@ def __init__(
endpoint_name,
sagemaker_session=None,
serializer=json_serializer,
deserializer=json_deserializer,
deserializer=JSONDeserializer(),
content_type=None,
model_name=None,
model_version=None,
4 changes: 2 additions & 2 deletions tests/integ/test_byo_estimator.py
@@ -86,7 +86,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = sagemaker.predictor.json_deserializer
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])

@@ -132,7 +132,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = sagemaker.predictor.json_deserializer
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])

6 changes: 3 additions & 3 deletions tests/integ/test_tfs.py
@@ -188,7 +188,7 @@ def test_predict_jsons_json_content_type(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/json",
accept="application/json",
)
@@ -205,7 +205,7 @@ def test_predict_jsons(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/jsons",
accept="application/jsons",
)
@@ -222,7 +222,7 @@ def test_predict_jsonlines(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/jsonlines",
accept="application/jsonlines",
)
4 changes: 2 additions & 2 deletions tests/integ/test_tuner.py
@@ -25,9 +25,9 @@
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.amazon.common import read_records
from sagemaker.chainer import Chainer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.estimator import Estimator
from sagemaker.mxnet.estimator import MXNet
from sagemaker.predictor import json_deserializer
from sagemaker.pytorch import PyTorch
from sagemaker.tensorflow import TensorFlow
from sagemaker.tuner import (
@@ -891,7 +891,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
predictor.serializer = _fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()

result = predictor.predict(datasets.one_p_mnist()[0][:10])

4 changes: 2 additions & 2 deletions tests/integ/test_tuner_multi_algo.py
@@ -21,8 +21,8 @@
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.analytics import HyperparameterTuningJobAnalytics
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.deserializers import JSONDeserializer
from sagemaker.estimator import Estimator
from sagemaker.predictor import json_deserializer
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -219,7 +219,7 @@ def _create_training_inputs(sagemaker_session):
def _make_prediction(predictor, data):
predictor.serializer = _prediction_data_serializer
predictor.content_type = CONTENT_TYPE_JSON
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()
return predictor.predict(data)


26 changes: 26 additions & 0 deletions tests/unit/sagemaker/test_deserializers.py
@@ -22,6 +22,7 @@
BytesDeserializer,
StreamDeserializer,
NumpyDeserializer,
JSONDeserializer,
)


@@ -119,3 +120,28 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
result = numpy_deserializer.deserialize(stream, "application/x-npy")

assert np.array_equal(array, result)


@pytest.fixture
def json_deserializer():
return JSONDeserializer()


def test_json_deserializer_array(json_deserializer):
result = json_deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == [1, 2, 3]


def test_json_deserializer_2dimensional(json_deserializer):
result = json_deserializer.deserialize(
io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json"
)

assert result == [[1, 2, 3], [3, 4, 5]]


def test_json_deserializer_invalid_data(json_deserializer):
with pytest.raises(ValueError) as error:
json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json")
assert "column" in str(error)
19 changes: 0 additions & 19 deletions tests/unit/test_predictor.py
@@ -23,7 +23,6 @@
from sagemaker.predictor import Predictor
from sagemaker.predictor import (
json_serializer,
json_deserializer,
csv_serializer,
csv_deserializer,
npy_serializer,
@@ -161,24 +160,6 @@ def test_csv_deserializer_2dimensional():
assert result == [["1", "2", "3"], ["3", "4", "5"]]


def test_json_deserializer_array():
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == [1, 2, 3]


def test_json_deserializer_2dimensional():
result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json")

assert result == [[1, 2, 3], [3, 4, 5]]


def test_json_deserializer_invalid_data():
with pytest.raises(ValueError) as error:
json_deserializer(io.BytesIO(b"[[1]"), "application/json")
assert "column" in str(error)


def test_npy_serializer_python_array():
array = [1, 2, 3]
result = npy_serializer(array)