Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(ChainerPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
48 changes: 0 additions & 48 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,51 +740,3 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):


numpy_deserializer = _NumpyDeserializer()


class _NPYSerializer(object):
    """Callable that turns request data into an NPY-formatted payload."""

    def __init__(self):
        """Record the content type associated with the serialized payload."""
        self.content_type = CONTENT_TYPE_NPY

    def __call__(self, data, dtype=None):
        """Serialize data into the request body in NPY format.

        Args:
            data (object): Data to be serialized. Can be a numpy array, list,
                file, or buffer.
            dtype: Passed to ``np.array`` when ``data`` is a list. It is not
                used for any other kind of input.

        Returns:
            object: NPY serialized data used for the request. For objects
                exposing ``read``, the result of ``data.read()`` is returned
                as-is.

        Raises:
            ValueError: If ``data`` is a numpy array with no elements or an
                empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("empty array can't be serialized")
            payload = data
        elif isinstance(data, list):
            if len(data) == 0:
                raise ValueError("empty array can't be serialized")
            payload = np.array(data, dtype)
        elif hasattr(data, "read"):
            # Files and buffers are assumed to hold npy-formatted data
            # already, so their content is forwarded untouched.
            return data.read()
        else:
            # Anything else is handed to numpy for conversion.
            payload = np.array(data)

        return _npy_serialize(payload)


def _npy_serialize(data):
    """Write ``data`` with ``np.save`` into memory and return the raw bytes.

    Args:
        data: Object to store; it is passed directly to ``np.save``.

    Returns:
        bytes: The complete contents of the in-memory buffer.
    """
    with BytesIO() as stream:
        np.save(stream, data)
        return stream.getvalue()


# Module-level _NPYSerializer instance, created once at import time.
npy_serializer = _NPYSerializer()
5 changes: 3 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
56 changes: 56 additions & 0 deletions src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from __future__ import absolute_import

import abc
import io

import numpy as np


class BaseSerializer(abc.ABC):
Expand All @@ -38,3 +41,56 @@ def serialize(self, data):
@abc.abstractmethod
def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""


class NumpySerializer(BaseSerializer):
    """Serialize data to bytes using the .npy format."""

    CONTENT_TYPE = "application/x-npy"

    def __init__(self, dtype=None):
        """Initialize the dtype.

        Args:
            dtype (str): The dtype used when a Python list is converted to a
                NumPy array before serialization. It is not applied to any
                other kind of input.
        """
        self.dtype = dtype

    def serialize(self, data):
        """Serialize data to bytes using the .npy format.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, or buffer.

        Returns:
            bytes: The data serialized in the .npy format. For objects that
                expose ``read``, the result of ``data.read()`` is returned
                unchanged.

        Raises:
            ValueError: If ``data`` is a NumPy array with no elements or an
                empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(data)

        if isinstance(data, list):
            if len(data) == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(np.array(data, self.dtype))

        # Files and buffers are assumed to already hold npy-formatted data,
        # so their content is passed through without re-encoding.
        if hasattr(data, "read"):
            return data.read()

        # Any other object is converted by NumPy first.
        return self._serialize_array(np.array(data))

    def _serialize_array(self, array):
        """Save a NumPy array to an in-memory buffer and return its contents.

        Args:
            array (numpy.ndarray): The array to serialize.

        Returns:
            bytes: The contents of the buffer after ``np.save`` has written
                the array to it.
        """
        buffer = io.BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()
5 changes: 3 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer
from sagemaker.sklearn import defaults

logger = logging.getLogger("sagemaker")
Expand All @@ -44,7 +45,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(SKLearnPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
from sagemaker.predictor import Predictor, csv_deserializer
from sagemaker.serializers import NumpySerializer
from sagemaker.xgboost.defaults import XGBOOST_NAME

logger = logging.getLogger("sagemaker")
Expand All @@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
chain.
"""
super(XGBoostPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), csv_deserializer
)


Expand Down
17 changes: 9 additions & 8 deletions tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.deserializers import StringDeserializer
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.mxnet import MXNet
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.retry import retries
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_multi_data_model_deploy_pretrained_models(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -216,7 +217,7 @@ def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, s
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=multi_data_model.sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -289,13 +290,13 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with NumpySerializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -390,13 +391,13 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with NumpySerializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -486,7 +487,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down
103 changes: 103 additions & 0 deletions tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io

import numpy as np
import pytest

from sagemaker.serializers import NumpySerializer


@pytest.fixture
def numpy_serializer():
    """Provide a NumpySerializer constructed with default arguments."""
    serializer = NumpySerializer()
    return serializer


def test_numpy_serializer_python_array(numpy_serializer):
    """A plain Python list survives a serialize / np.load round trip."""
    values = [1, 2, 3]
    payload = numpy_serializer.serialize(values)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(values, restored)


def test_numpy_serializer_python_array_with_dtype():
    """A list is converted with the dtype given to the constructor."""
    values = [1, 2, 3]
    payload = NumpySerializer(dtype="float16").serialize(values)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(values, restored)
    assert restored.dtype == "float16"


def test_numpy_serializer_numpy_valid_2_dimensional(numpy_serializer):
    """A 2-D NumPy array survives a serialize / np.load round trip."""
    matrix = np.array([[1, 2, 3], [3, 4, 5]])
    payload = numpy_serializer.serialize(matrix)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(matrix, restored)


def test_numpy_serializer_numpy_valid_multidimensional(numpy_serializer):
    """A 4-D NumPy array survives a serialize / np.load round trip."""
    tensor = np.ones((10, 10, 10, 10))
    payload = numpy_serializer.serialize(tensor)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(tensor, restored)


def test_numpy_serializer_numpy_valid_list_of_strings(numpy_serializer):
    """A NumPy array of strings survives a serialize / np.load round trip."""
    words = np.array(["one", "two", "three"])
    payload = numpy_serializer.serialize(words)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(words, restored)


def test_numpy_serializer_from_buffer_or_file(numpy_serializer):
    """A readable stream holding .npy data is accepted as input."""
    expected = np.ones((2, 3))

    # Prepare a stream that already contains the array in .npy form and
    # rewind it so the serializer reads from the beginning.
    source = io.BytesIO()
    np.save(source, expected)
    source.seek(0)

    payload = numpy_serializer.serialize(source)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(expected, restored)


def test_numpy_serializer_object(numpy_serializer):
    """An input that is not a list, array, or stream is still serialized."""
    # Named ``data`` rather than ``object`` so the builtin is not shadowed.
    data = {1, 2, 3}

    result = numpy_serializer.serialize(data)

    # The set is stored as a Python object, so loading needs allow_pickle.
    assert np.array_equal(np.array(data), np.load(io.BytesIO(result), allow_pickle=True))


def test_numpy_serializer_list_of_empty(numpy_serializer):
    """An array built from empty rows has no elements and is rejected."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([[], []]))

    # Check the raised exception's own message (``.value``) rather than the
    # string form of the ExceptionInfo wrapper.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_numpy_invalid_empty(numpy_serializer):
    """An empty NumPy array is rejected with a ValueError."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([]))

    # Check the raised exception's own message (``.value``) rather than the
    # string form of the ExceptionInfo wrapper.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_python_invalid_empty(numpy_serializer):
    """An empty Python list is rejected with a ValueError."""
    with pytest.raises(ValueError) as error:
        numpy_serializer.serialize([])

    # The assertion must sit after the ``with`` block: inside it, the line
    # would never run because ``serialize`` raises first. The message is
    # read from the exception itself via ``.value``.
    assert "empty array" in str(error.value)
Loading