Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,25 @@ def deserialize(self, data, content_type):
return data.read()
finally:
data.close()


class StreamDeserializer(BaseDeserializer):
"""Returns the data and content-type received from an inference endpoint.

It is the user's responsibility to close the data stream once they're done
reading it.
"""

ACCEPT = "*/*"

def deserialize(self, data, content_type):
"""Returns a stream of the response body and the MIME type of the data.

Args:
data (object): A stream of bytes.
content_type (str): The MIME type of the data.

Returns:
tuple: A two-tuple containing the stream and content-type.
"""
return data, content_type
25 changes: 0 additions & 25 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,31 +623,6 @@ def __call__(self, stream, content_type):
csv_deserializer = _CsvDeserializer()


class StreamDeserializer(object):
"""Returns the tuple of the response stream and the content-type of the response.
It is the receivers responsibility to close the stream when they're done
reading the stream.

Args:
accept (str): The Accept header to send to the server (optional).
"""

def __init__(self, accept=None):
"""
Args:
accept:
"""
self.accept = accept

def __call__(self, stream, content_type):
"""
Args:
stream:
content_type:
"""
return (stream, content_type)


class _JsonSerializer(object):
"""Placeholder docstring"""

Expand Down
15 changes: 14 additions & 1 deletion tests/unit/sagemaker/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import io

from sagemaker.deserializers import StringDeserializer, BytesDeserializer
from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer


def test_string_deserializer():
Expand All @@ -31,3 +31,16 @@ def test_bytes_deserializer():
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == b"[1, 2, 3]"


def test_stream_deserializer():
deserializer = StreamDeserializer()

stream, content_type = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
try:
result = stream.read()
finally:
stream.close()

assert result == b"[1, 2, 3]"
assert content_type == "application/json"
8 changes: 0 additions & 8 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
json_deserializer,
csv_serializer,
csv_deserializer,
StreamDeserializer,
numpy_deserializer,
npy_serializer,
_NumpyDeserializer,
Expand Down Expand Up @@ -182,13 +181,6 @@ def test_json_deserializer_invalid_data():
assert "column" in str(error)


def test_stream_deserializer():
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
result = stream.read()
assert result == b"[1, 2, 3]"
assert content_type == "application/json"


def test_npy_serializer_python_array():
array = [1, 2, 3]
result = npy_serializer(array)
Expand Down