@@ -20,6 +20,8 @@ from libcpp.vector cimport vector as c_vector
 from cpython.ref cimport PyObject
 from cython.operator cimport dereference as deref
 
+import cloudpickle as pickle
+
 from pyarrow.lib cimport Buffer, NativeFile, check_status, _RecordBatchFileWriter
 
 cdef extern from "arrow/python/python_to_arrow.h":
@@ -30,6 +32,10 @@ cdef extern from "arrow/python/python_to_arrow.h":
 
     cdef shared_ptr[CRecordBatch] MakeBatch(shared_ptr[CArray] data)
 
+cdef extern PyObject* pyarrow_serialize_callback
+
+cdef extern PyObject* pyarrow_deserialize_callback
+
 cdef extern from "arrow/python/arrow_to_python.h":
 
     cdef CStatus DeserializeList(shared_ptr[CArray] array, int32_t start_idx,
@@ -45,6 +51,81 @@ cdef class PythonObject:
     def __cinit__(self):
         pass
 
+# Types with special serialization handlers
+type_to_type_id = dict()
+whitelisted_types = dict()
+types_to_pickle = set()
+custom_serializers = dict()
+custom_deserializers = dict()
+
+def register_type(type, type_id, pickle=False, custom_serializer=None, custom_deserializer=None):
+    """Add type to the list of types we can serialize.
+
+    Args:
+        type (type): The type that we can serialize.
+        type_id: A string of bytes used to identify the type.
+        pickle (bool): True if the serialization should be done with pickle,
+            False if it should be done efficiently with Arrow.
+        custom_serializer: This argument is optional, but can be provided to
+            serialize objects of the class in a particular way.
+        custom_deserializer: This argument is optional, but can be provided to
+            deserialize objects of the class in a particular way.
+    """
+    type_to_type_id[type] = type_id
+    whitelisted_types[type_id] = type
+    if pickle:
+        types_to_pickle.add(type_id)
+    if custom_serializer is not None:
+        custom_serializers[type_id] = custom_serializer
+        custom_deserializers[type_id] = custom_deserializer
+
+def serialization_callback(obj):
+    if type(obj) not in type_to_type_id:
+        raise Exception("error")
+    type_id = type_to_type_id[type(obj)]
+    if type_id in types_to_pickle:
+        serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
+    elif type_id in custom_serializers:
+        serialized_obj = {"data": custom_serializers[type_id](obj)}
+    else:
+        if hasattr(obj, "__dict__"):
+            serialized_obj = obj.__dict__
+        else:
+            raise Exception("error")
+    return dict(serialized_obj, **{"_pytype_": type_id})
+
+def deserialization_callback(serialized_obj):
+    type_id = serialized_obj["_pytype_"]
+
+    if "pickle" in serialized_obj:
+        # The object was pickled, so unpickle it.
+        obj = pickle.loads(serialized_obj["data"])
+    else:
+        assert type_id not in types_to_pickle
+        if type_id not in whitelisted_types:
+            raise Exception("error")
+        type = whitelisted_types[type_id]
+        if type_id in custom_deserializers:
+            obj = custom_deserializers[type_id](serialized_obj["data"])
+        else:
+            # In this case, serialized_obj should just be the __dict__ field.
+            if "_ray_getnewargs_" in serialized_obj:
+                obj = type.__new__(type, *serialized_obj["_ray_getnewargs_"])
+            else:
+                obj = type.__new__(type)
+            serialized_obj.pop("_pytype_")
+            obj.__dict__.update(serialized_obj)
+    return obj
+
+def set_serialization_callbacks(serialization_callback, deserialization_callback):
+    global pyarrow_serialize_callback, pyarrow_deserialize_callback
+    # TODO(pcm): Are refcounts correct here?
+    print("setting serialization callback")
+    pyarrow_serialize_callback = <PyObject*> serialization_callback
+    print("val1 is", <object> pyarrow_serialize_callback)
+    pyarrow_deserialize_callback = <PyObject*> deserialization_callback
+    print("val2 is", <object> pyarrow_deserialize_callback)
+
 # Main entry point for serialization
 def serialize_sequence(object value):
     cdef int32_t recursion_depth = 0
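For context, a minimal usage sketch of the registration hooks added above; the Foo class and the b"foo" type id are illustrative and not part of this change:

class Foo(object):
    def __init__(self, x):
        self.x = x

# Serialize Foo instances through their __dict__ (no pickling, no custom handlers).
register_type(Foo, b"foo", pickle=False)

# Install the module-level callbacks so the C++ serializer can call back into Python.
set_serialization_callbacks(serialization_callback, deserialization_callback)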
@@ -57,18 +138,20 @@ def serialize_sequence(object value):
     sequences.push_back(<PyObject*> value)
     check_status(SerializeSequences(sequences, recursion_depth, &array, tensors))
     result.batch = MakeBatch(array)
+    num_tensors = 0
     for tensor in tensors:
         check_status(NdarrayToTensor(c_default_memory_pool(), <object> tensor, &out))
         result.tensors.push_back(out)
-    return result
+        num_tensors += 1
+    return result, num_tensors
 
 # Main entry point for deserialization
 def deserialize_sequence(PythonObject value, object base):
     cdef PyObject* result
     check_status(DeserializeList(deref(value.batch).column(0), 0, deref(value.batch).num_rows(), <PyObject*> base, value.tensors, &result))
     return <object> result
 
-def write_python_object(PythonObject value, NativeFile sink):
+def write_python_object(PythonObject value, int32_t num_tensors, NativeFile sink):
     cdef shared_ptr[OutputStream] stream
     sink.write_handle(&stream)
     cdef shared_ptr[CRecordBatchStreamWriter] writer
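With this change serialize_sequence returns a pair and write_python_object takes the tensor count explicitly, so callers have to thread num_tensors through. A sketch of the updated call pattern (sink is assumed to be any writable pyarrow NativeFile and np is numpy; both are illustrative here):

# serialize_sequence now returns (PythonObject, num_tensors); forward the count to the writer.
serialized, num_tensors = serialize_sequence([1, "hello", np.zeros(3)])
write_python_object(serialized, num_tensors, sink)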
@@ -79,6 +162,9 @@ def write_python_object(PythonObject value, NativeFile sink):
     cdef int64_t body_length
 
     with nogil:
+        # write number of tensors
+        check_status(stream.get().Write(<uint8_t*> &num_tensors, sizeof(int32_t)))
+
         check_status(CRecordBatchStreamWriter.Open(stream.get(), schema, &writer))
         check_status(deref(writer).WriteRecordBatch(deref(batch)))
         check_status(deref(writer).Close())
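The stream written by write_python_object now begins with the raw bytes of a C int32 holding the tensor count, followed by the record batch stream and the tensors. A small sketch of how an out-of-band reader could peek at that prefix (native byte order is assumed, since the bytes come straight from the in-memory int32; peek_num_tensors is a hypothetical helper):

import struct

def peek_num_tensors(buf):
    # First 4 bytes are the raw, native-endian int32 written above.
    (num_tensors,) = struct.unpack_from("=i", buf, 0)
    return num_tensors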
@@ -93,18 +179,21 @@ def read_python_object(NativeFile source):
     cdef shared_ptr[CRecordBatchStreamReader] reader
     cdef shared_ptr[CTensor] tensor
     cdef int64_t offset
+    cdef int64_t bytes_read
+    cdef int32_t num_tensors
 
     with nogil:
+        # read number of tensors
+        check_status(stream.get().Read(sizeof(int32_t), &bytes_read, <uint8_t*> &num_tensors))
+
         check_status(CRecordBatchStreamReader.Open(<shared_ptr[InputStream]> stream, &reader))
         check_status(reader.get().ReadNextRecordBatch(&result.batch))
 
         check_status(deref(stream).Tell(&offset))
 
-        while True:
-            s = ReadTensor(offset, stream.get(), &tensor)
+        for i in range(num_tensors):
+            check_status(ReadTensor(offset, stream.get(), &tensor))
             result.tensors.push_back(tensor)
-            if not s.ok():
-                break
             check_status(deref(stream).Tell(&offset))
 
     return result
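Putting the writer and reader changes together, a round trip might look like the sketch below; pa.OSFile and the /tmp path are just for illustration, and None is passed as the base object here purely as a placeholder:

serialized, num_tensors = serialize_sequence(value)
with pa.OSFile("/tmp/obj.arrow", mode="wb") as sink:
    write_python_object(serialized, num_tensors, sink)
with pa.OSFile("/tmp/obj.arrow", mode="rb") as source:
    restored = read_python_object(source)
obj = deserialize_sequence(restored, None)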