|  | 
| 22 | 22 | 
 | 
| 23 | 23 | from sagemaker.amazon.record_pb2 import Record | 
| 24 | 24 | from sagemaker.deserializers import BaseDeserializer | 
|  | 25 | +from sagemaker.serializers import BaseSerializer | 
| 25 | 26 | from sagemaker.utils import DeferredError | 
| 26 | 27 | 
 | 
| 27 | 28 | 
 | 
| 28 |  | -class numpy_to_record_serializer(object): | 
| 29 |  | -    """Placeholder docstring""" | 
|  | 29 | +class RecordSerializer(BaseSerializer): | 
|  | 30 | +    """Serialize a NumPy array for an inference request.""" | 
| 30 | 31 | 
 | 
| 31 |  | -    def __init__(self, content_type="application/x-recordio-protobuf"): | 
| 32 |  | -        """ | 
| 33 |  | -        Args: | 
| 34 |  | -            content_type: | 
| 35 |  | -        """ | 
| 36 |  | -        self.content_type = content_type | 
|  | 32 | +    CONTENT_TYPE = "application/x-recordio-protobuf" | 
|  | 33 | + | 
|  | 34 | +    def serialize(self, data): | 
|  | 35 | +        """Serialize a NumPy array into a buffer containing RecordIO records. | 
| 37 | 36 | 
 | 
| 38 |  | -    def __call__(self, array): | 
| 39 |  | -        """ | 
| 40 | 37 |         Args: | 
| 41 |  | -            array: | 
|  | 38 | +            data (numpy.ndarray): The data to serialize. | 
|  | 39 | +
 | 
|  | 40 | +        Returns: | 
|  | 41 | +            io.BytesIO: A buffer containing the data serialized as records. | 
| 42 | 42 |         """ | 
| 43 |  | -        if len(array.shape) == 1: | 
| 44 |  | -            array = array.reshape(1, array.shape[0]) | 
| 45 |  | -        assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array" | 
| 46 |  | -        buf = io.BytesIO() | 
| 47 |  | -        write_numpy_to_dense_tensor(buf, array) | 
| 48 |  | -        buf.seek(0) | 
| 49 |  | -        return buf | 
|  | 43 | +        if len(data.shape) == 1: | 
|  | 44 | +            data = data.reshape(1, data.shape[0]) | 
|  | 45 | + | 
|  | 46 | +        if len(data.shape) != 2: | 
|  | 47 | +            raise ValueError( | 
|  | 48 | +                "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) | 
|  | 49 | +            ) | 
|  | 50 | + | 
|  | 51 | +        buffer = io.BytesIO() | 
|  | 52 | +        write_numpy_to_dense_tensor(buffer, data) | 
|  | 53 | +        buffer.seek(0) | 
|  | 54 | + | 
|  | 55 | +        return buffer | 
| 50 | 56 | 
 | 
| 51 | 57 | 
 | 
| 52 | 58 | class RecordDeserializer(BaseDeserializer): | 
|  | 
0 commit comments