
Commit 8b2ffe6

committed: working version
1 parent bd36c83 commit 8b2ffe6

3 files changed: +169 -17 lines changed

cpp/src/arrow/python/python_to_arrow.cc

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ Status CallCustomSerializationCallback(PyObject* elem, PyObject** serialized_obj
     // must be decremented. This is done in SerializeDict in this file.
     PyObject* result = PyObject_CallObject(pyarrow_serialize_callback, arglist);
     Py_XDECREF(arglist);
-    if (!result) { return Status::NotImplemented("python error"); }
+    if (!result || !PyDict_Check(result)) {
+      // TODO(pcm): Propagate Python error here if !result
+      return Status::TypeError("serialization callback must return a valid dictionary");
+    }
     *serialized_object = result;
   }
   return Status::OK();
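
This tightens the contract on the Python side: whatever callable is installed as pyarrow_serialize_callback must now return a dict, and any other return value (or a Python error, until the TODO is resolved) surfaces as a TypeError status instead of failing later in SerializeDict. A minimal sketch of a conforming callback; the function name and "repr" tag are hypothetical, and the "_pytype_" key follows the convention introduced in serialization.pxi below:

def minimal_serialize_callback(obj):
    # Must return a dict: the check above rejects any other return type.
    return {"data": repr(obj), "_pytype_": b"repr"}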

python/pyarrow/serialization.pxi

Lines changed: 95 additions & 6 deletions
@@ -20,6 +20,8 @@ from libcpp.vector cimport vector as c_vector
 from cpython.ref cimport PyObject
 from cython.operator cimport dereference as deref

+import cloudpickle as pickle
+
 from pyarrow.lib cimport Buffer, NativeFile, check_status, _RecordBatchFileWriter

 cdef extern from "arrow/python/python_to_arrow.h":
@@ -30,6 +32,10 @@ cdef extern from "arrow/python/python_to_arrow.h":

     cdef shared_ptr[CRecordBatch] MakeBatch(shared_ptr[CArray] data)

+cdef extern PyObject *pyarrow_serialize_callback
+
+cdef extern PyObject *pyarrow_deserialize_callback
+
 cdef extern from "arrow/python/arrow_to_python.h":

     cdef CStatus DeserializeList(shared_ptr[CArray] array, int32_t start_idx,
@@ -45,6 +51,81 @@ cdef class PythonObject:
     def __cinit__(self):
         pass

+# Types with special serialization handlers
+type_to_type_id = dict()
+whitelisted_types = dict()
+types_to_pickle = set()
+custom_serializers = dict()
+custom_deserializers = dict()
+
+def register_type(type, type_id, pickle=False, custom_serializer=None, custom_deserializer=None):
+    """Add type to the list of types we can serialize.
+
+    Args:
+        type (type): The type that we can serialize.
+        type_id: A string of bytes used to identify the type.
+        pickle (bool): True if the serialization should be done with pickle.
+            False if it should be done efficiently with Arrow.
+        custom_serializer: This argument is optional, but can be provided to
+            serialize objects of the class in a particular way.
+        custom_deserializer: This argument is optional, but can be provided to
+            deserialize objects of the class in a particular way.
+    """
+    type_to_type_id[type] = type_id
+    whitelisted_types[type_id] = type
+    if pickle:
+        types_to_pickle.add(type_id)
+    if custom_serializer is not None:
+        custom_serializers[type_id] = custom_serializer
+        custom_deserializers[type_id] = custom_deserializer
+
+def serialization_callback(obj):
+    if type(obj) not in type_to_type_id:
+        raise TypeError("type %s not registered for serialization" % type(obj))
+    type_id = type_to_type_id[type(obj)]
+    if type_id in types_to_pickle:
+        serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
+    elif type_id in custom_serializers:
+        serialized_obj = {"data": custom_serializers[type_id](obj)}
+    else:
+        if hasattr(obj, "__dict__"):
+            serialized_obj = obj.__dict__
+        else:
+            raise TypeError("cannot serialize object of type %s: no custom serializer and no __dict__" % type(obj))
+    return dict(serialized_obj, **{"_pytype_": type_id})
+
+def deserialization_callback(serialized_obj):
+    type_id = serialized_obj["_pytype_"]
+
+    if "pickle" in serialized_obj:
+        # The object was pickled, so unpickle it.
+        obj = pickle.loads(serialized_obj["data"])
+    else:
+        assert type_id not in types_to_pickle
+        if type_id not in whitelisted_types:
+            raise ValueError("type ID %s not registered for deserialization" % type_id)
+        type = whitelisted_types[type_id]
+        if type_id in custom_deserializers:
+            obj = custom_deserializers[type_id](serialized_obj["data"])
+        else:
+            # In this case, serialized_obj should just be the __dict__ field.
+            if "_ray_getnewargs_" in serialized_obj:
+                obj = type.__new__(type, *serialized_obj["_ray_getnewargs_"])
+            else:
+                obj = type.__new__(type)
+            serialized_obj.pop("_pytype_")
+            obj.__dict__.update(serialized_obj)
+    return obj
+
+def set_serialization_callbacks(serialization_callback, deserialization_callback):
+    global pyarrow_serialize_callback, pyarrow_deserialize_callback
+    # TODO(pcm): Are refcounts correct here?
+    print("setting serialization callback")
+    pyarrow_serialize_callback = <PyObject*> serialization_callback
+    print("val1 is", <object> pyarrow_serialize_callback)
+    pyarrow_deserialize_callback = <PyObject*> deserialization_callback
+    print("val2 is", <object> pyarrow_deserialize_callback)
+
 # Main entry point for serialization
 def serialize_sequence(object value):
     cdef int32_t recursion_depth = 0
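
Taken together, register_type and set_serialization_callbacks form the registration API added by this commit. A usage sketch under stated assumptions (Foo and its 20-byte type ID are hypothetical; the callbacks are the module-level ones defined above):

class Foo(object):
    def __init__(self, x):
        self.x = x

# pickle=False and no custom serializer, so Foo instances take the
# __dict__ fallback path in serialization_callback.
register_type(Foo, 20 * b"\x02")

# Install the callbacks so the C++ code can hand unhandled types back to Python.
set_serialization_callbacks(serialization_callback, deserialization_callback)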
@@ -57,18 +138,20 @@ def serialize_sequence(object value):
     sequences.push_back(<PyObject*> value)
     check_status(SerializeSequences(sequences, recursion_depth, &array, tensors))
     result.batch = MakeBatch(array)
+    num_tensors = 0
     for tensor in tensors:
         check_status(NdarrayToTensor(c_default_memory_pool(), <object> tensor, &out))
         result.tensors.push_back(out)
-    return result
+        num_tensors += 1
+    return result, num_tensors

 # Main entry point for deserialization
 def deserialize_sequence(PythonObject value, object base):
     cdef PyObject* result
     check_status(DeserializeList(deref(value.batch).column(0), 0, deref(value.batch).num_rows(), <PyObject*> base, value.tensors, &result))
     return <object> result

-def write_python_object(PythonObject value, NativeFile sink):
+def write_python_object(PythonObject value, int32_t num_tensors, NativeFile sink):
     cdef shared_ptr[OutputStream] stream
     sink.write_handle(&stream)
     cdef shared_ptr[CRecordBatchStreamWriter] writer
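
The caller is now responsible for threading the tensor count from serialize_sequence into write_python_object, as the tests below do. A sketch of the new calling convention, assuming sink is an open writable NativeFile:

import numpy as np
import pyarrow as pa

# serialize_sequence now returns the tensor count alongside the batch...
serialized, num_tensors = pa.lib.serialize_sequence(
    [np.array([1, 2, 3]), None, np.array([4, 5, 6])])
# ...which write_python_object embeds at the head of the stream.
pa.lib.write_python_object(serialized, num_tensors, sink)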
@@ -79,6 +162,9 @@ def write_python_object(PythonObject value, NativeFile sink):
     cdef int64_t body_length

     with nogil:
+        # write number of tensors
+        check_status(stream.get().Write(<uint8_t*> &num_tensors, sizeof(int32_t)))
+
         check_status(CRecordBatchStreamWriter.Open(stream.get(), schema, &writer))
         check_status(deref(writer).WriteRecordBatch(deref(batch)))
         check_status(deref(writer).Close())
@@ -93,18 +179,21 @@ def read_python_object(NativeFile source):
     cdef shared_ptr[CRecordBatchStreamReader] reader
     cdef shared_ptr[CTensor] tensor
     cdef int64_t offset
+    cdef int64_t bytes_read
+    cdef int32_t num_tensors

     with nogil:
+        # read number of tensors
+        check_status(stream.get().Read(sizeof(int32_t), &bytes_read, <uint8_t*> &num_tensors))
+
         check_status(CRecordBatchStreamReader.Open(<shared_ptr[InputStream]> stream, &reader))
         check_status(reader.get().ReadNextRecordBatch(&result.batch))

         check_status(deref(stream).Tell(&offset))

-        while True:
-            s = ReadTensor(offset, stream.get(), &tensor)
+        for i in range(num_tensors):
+            check_status(ReadTensor(offset, stream.get(), &tensor))
             result.tensors.push_back(tensor)
-            if not s.ok():
-                break
             check_status(deref(stream).Tell(&offset))

     return result
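
As the writer and reader above imply, the stream layout is now: a raw int32 tensor count in native byte order (no alignment padding or endianness marker), followed by the record batch stream, followed by exactly num_tensors Arrow tensors. This replaces the old loop that probed ReadTensor until it returned a non-OK status. A sketch of peeking at the count from plain Python, assuming the file was written by write_python_object on the same machine:

import struct

def peek_num_tensors(path):
    # The first sizeof(int32_t) bytes are the tensor count.
    with open(path, "rb") as f:
        (num_tensors,) = struct.unpack("@i", f.read(4))
    return num_tensors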

python/pyarrow/tests/test_serialization.py

Lines changed: 70 additions & 10 deletions
@@ -20,25 +20,85 @@
 from __future__ import print_function

 import os
+import string
+import sys

 import pyarrow as pa
 import numpy as np
+from numpy.testing import assert_equal

-obj = pa.lib.serialize_sequence([np.array([1, 2, 3]), None, np.array([4, 5, 6])])
+def serialization_callback(value):
+    if isinstance(value, np.ndarray):
+        return {"data": value.tolist(), "_pytype_": str(value.dtype.str)}
+    else:
+        return {"data": str(value), "_pytype_": "long"}

-SIZE = 4096
+def deserialization_callback(value):
+    data = value["data"]
+    if value["_pytype_"] == "long":
+        return int(data)
+    else:
+        return np.array(data, dtype=np.dtype(value["_pytype_"]))
+
+pa.lib.set_serialization_callbacks(serialization_callback, deserialization_callback)
+
+def array_custom_serializer(obj):
+    return obj.tolist(), obj.dtype.str
+
+def array_custom_deserializer(serialized_obj):
+    return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
+
+pa.lib.register_type(np.ndarray, 20 * b"\x01", pickle=False,
+                     custom_serializer=array_custom_serializer,
+                     custom_deserializer=array_custom_deserializer)
+
+if sys.version_info >= (3, 0):
+    long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
+else:
+    long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])]  # noqa: E501,F821
+
+PRIMITIVE_OBJECTS = [
+    0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999,
+    [1 << 100, [1 << 100]], "a", string.printable, "\u262F",
+    u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True,
+    False, [], (), {}, np.int8(3), np.int32(4), np.int64(5),
+    np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.9),
+    np.float64(1.9), np.zeros([100, 100]),
+    np.random.normal(size=[100, 100]), np.array(["hi", 3]),
+    np.array(["hi", 3], dtype=object)] + long_extras
+
+COMPLEX_OBJECTS = [
+    [[[[[[[[[[[[]]]]]]]]]]]],
+    {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
+    {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
+        (): {(): {}}}}}}}}}}}}},
+    ((((((((((),),),),),),),),),),
+    {"a": {"b": {"c": {"d": {}}}}}]
+
+def serialization_roundtrip(value, f):
+    f.seek(0)
+    serialized, num_tensors = pa.lib.serialize_sequence(value)
+    pa.lib.write_python_object(serialized, num_tensors, f)
+    f.seek(0)
+    res = pa.lib.read_python_object(f)
+    base = None
+    result = pa.lib.deserialize_sequence(res, base)
+    assert_equal(value, result)
+
+# Create a large memory-mapped file
+SIZE = 100 * 1024 * 1024  # 100 MB
 arr = np.random.randint(0, 256, size=SIZE).astype('u1')
 data = arr.tobytes()[:SIZE]
-path = os.path.join("/tmp/temp")
+path = os.path.join("/tmp/pyarrow-temp-file")
 with open(path, 'wb') as f:
     f.write(data)

-f = pa.memory_map(path, mode="w")
-
-pa.lib.write_python_object(obj, f)
-
-f = pa.memory_map(path, mode="r")
+MEMORY_MAPPED_FILE = pa.memory_map(path, mode="r+")

-res = pa.lib.read_python_object(f)
+def test_primitive_serialization():
+    for obj in PRIMITIVE_OBJECTS:
+        serialization_roundtrip([obj], MEMORY_MAPPED_FILE)

-pa.lib.deserialize_sequence(res, res)
+def test_complex_serialization():
+    for obj in COMPLEX_OBJECTS:
+        serialization_roundtrip([obj], MEMORY_MAPPED_FILE)
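
The tests above cover primitives, containers, and the custom ndarray serializer, but not the __dict__ fallback for user-defined classes. A sketch of such a test (Point and its type ID are hypothetical; __eq__ is defined so assert_equal can compare the round-tripped instances):

class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

pa.lib.register_type(Point, 20 * b"\x03")

def test_custom_class_serialization():
    serialization_roundtrip([Point(1, 2)], MEMORY_MAPPED_FILE)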
