diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 908ffcc7aa..9361765da0 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -46,18 +46,28 @@ def input_fn(input_data, content_type, context=None): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 489cc1bc1e..058103a1fd 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -68,18 +68,28 @@ def input_fn(input_data, content_type): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 517c774bbc..49cec5aab5 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -71,18 +71,28 @@ def input_fn(input_data, content_type): if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: return schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], )