diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 4dc694ee66..0b1fc71755 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -89,3 +89,25 @@ def deserialize(self, data, content_type): return data.read() finally: data.close() + + +class StreamDeserializer(BaseDeserializer): + """Returns the data and content-type received from an inference endpoint. + + It is the user's responsibility to close the data stream once they're done + reading it. + """ + + ACCEPT = "*/*" + + def deserialize(self, data, content_type): + """Returns a stream of the response body and the MIME type of the data. + + Args: + data (object): A stream of bytes. + content_type (str): The MIME type of the data. + + Returns: + tuple: A two-tuple containing the stream and content-type. + """ + return data, content_type diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index fad236ea0a..7b5f695c3e 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -623,31 +623,6 @@ def __call__(self, stream, content_type): csv_deserializer = _CsvDeserializer() -class StreamDeserializer(object): - """Returns the tuple of the response stream and the content-type of the response. - It is the receivers responsibility to close the stream when they're done - reading the stream. - - Args: - accept (str): The Accept header to send to the server (optional). - """ - - def __init__(self, accept=None): - """ - Args: - accept: - """ - self.accept = accept - - def __call__(self, stream, content_type): - """ - Args: - stream: - content_type: - """ - return (stream, content_type) - - class _JsonSerializer(object): """Placeholder docstring""" diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index e4e3149b7a..bbc0e3359b 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -14,7 +14,7 @@ import io -from sagemaker.deserializers import StringDeserializer, BytesDeserializer +from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer def test_string_deserializer(): @@ -31,3 +31,16 @@ def test_bytes_deserializer(): result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") assert result == b"[1, 2, 3]" + + +def test_stream_deserializer(): + deserializer = StreamDeserializer() + + stream, content_type = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") + try: + result = stream.read() + finally: + stream.close() + + assert result == b"[1, 2, 3]" + assert content_type == "application/json" diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 648758bc3e..d19c3911f7 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -26,7 +26,6 @@ json_deserializer, csv_serializer, csv_deserializer, - StreamDeserializer, numpy_deserializer, npy_serializer, _NumpyDeserializer, @@ -182,13 +181,6 @@ def test_json_deserializer_invalid_data(): assert "column" in str(error) -def test_stream_deserializer(): - stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") - result = stream.read() - assert result == b"[1, 2, 3]" - assert content_type == "application/json" - - def test_npy_serializer_python_array(): array = [1, 2, 3] result = npy_serializer(array)