Skip to content

Commit 3af1c67

Browse files
committed
Add deserialization path (still need to verify that the base object and reference counting are handled correctly)
1 parent deb3b46 commit 3af1c67

File tree

4 files changed

+232
-1
lines changed

4 files changed

+232
-1
lines changed

cpp/src/arrow/python/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ set(ARROW_PYTHON_TEST_LINK_LIBS ${ARROW_PYTHON_MIN_TEST_LIBS})
4343

4444
set(ARROW_PYTHON_SRCS
4545
arrow_to_pandas.cc
46+
arrow_to_python.cc
4647
builtin_convert.cc
4748
common.cc
4849
config.cc
@@ -86,6 +87,7 @@ endif()
8687
install(FILES
8788
api.h
8889
arrow_to_pandas.h
90+
arrow_to_python.h
8991
builtin_convert.h
9092
common.h
9193
config.h
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow_to_python.h"
19+
20+
#include <arrow/util/logging.h>
21+
22+
#include "numpy_convert.h"
23+
24+
namespace arrow {
25+
26+
#if PY_MAJOR_VERSION >= 3
27+
#define PyInt_FromLong PyLong_FromLong
28+
#endif
29+
30+
// Converts the element at position `index` of `arr` into a new Python object.
//
// The conversion is chosen from the Arrow type of `arr`:
//   - BOOL, INT64, BINARY, STRING, FLOAT, DOUBLE become the matching Python
//     scalar (bool, int, bytes, str, float).
//   - STRUCT is a nested container; the name of its first field ("list",
//     "tuple" or "dict") selects the deserializer applied to the list slice
//     [value_offset(index), value_offset(index + 1)).
//   - INT32 holds an index into `tensors` and is turned into an ndarray by
//     DeserializeArray.
//
// `type` is the union tag of the element and is only used in error messages.
// `base` and `tensors` are forwarded unchanged to the nested deserializers.
//
// On success `*result` holds a new reference. On failure `*result` is null
// (or whatever the nested deserializer left there) and a non-OK Status is
// returned.
Status get_value(std::shared_ptr<Array> arr, int32_t index, int32_t type, PyObject* base,
    const std::vector<std::shared_ptr<Tensor>>& tensors, PyObject** result) {
  *result = nullptr;
  switch (arr->type()->id()) {
    case Type::BOOL:
      *result =
          PyBool_FromLong(std::static_pointer_cast<BooleanArray>(arr)->Value(index));
      break;
    case Type::INT64:
      *result = PyInt_FromLong(std::static_pointer_cast<Int64Array>(arr)->Value(index));
      break;
    case Type::BINARY: {
      int32_t nchars;
      const uint8_t* str =
          std::static_pointer_cast<BinaryArray>(arr)->GetValue(index, &nchars);
      *result = PyBytes_FromStringAndSize(reinterpret_cast<const char*>(str), nchars);
      break;
    }
    case Type::STRING: {
      int32_t nchars;
      const uint8_t* str =
          std::static_pointer_cast<StringArray>(arr)->GetValue(index, &nchars);
      *result = PyUnicode_FromStringAndSize(reinterpret_cast<const char*>(str), nchars);
      break;
    }
    case Type::FLOAT:
      *result =
          PyFloat_FromDouble(std::static_pointer_cast<FloatArray>(arr)->Value(index));
      break;
    case Type::DOUBLE:
      *result =
          PyFloat_FromDouble(std::static_pointer_cast<DoubleArray>(arr)->Value(index));
      break;
    case Type::STRUCT: {
      auto s = std::static_pointer_cast<StructArray>(arr);
      auto l = std::static_pointer_cast<ListArray>(s->field(0));
      // The name of the first struct field encodes which container was stored.
      const auto kind = s->type()->child(0)->name();
      const int32_t begin = l->value_offset(index);
      const int32_t end = l->value_offset(index + 1);
      if (kind == "list") {
        return DeserializeList(l->values(), begin, end, base, tensors, result);
      }
      if (kind == "tuple") {
        return DeserializeTuple(l->values(), begin, end, base, tensors, result);
      }
      if (kind == "dict") {
        return DeserializeDict(l->values(), begin, end, base, tensors, result);
      }
      // Return explicitly: without this, a build with DCHECK disabled would
      // fall through into the tensor case below and misread the struct.
      return Status::Invalid("unrecognized container kind: " + kind);
    }
    // We use an Int32Builder here to distinguish the tensor indices from
    // the Type::INT64 above (see tensor_indices_ in sequence.h).
    case Type::INT32:
      return DeserializeArray(arr, index, base, tensors, result);
    default:
      return Status::NotImplemented(
          "union tag not recognized " + std::to_string(type));
  }
  // Only the scalar cases reach this point; a null result means the Python
  // C API call failed (for example invalid UTF-8 or out of memory).
  if (*result == nullptr) {
    return Status::UnknownError("failed to create Python object");
  }
  return Status::OK();
}
88+
89+
// Function body shared by DeserializeList and DeserializeTuple.
//
// Expects the enclosing function to provide `array`, `start_idx`, `stop_idx`,
// `base`, `tensors` and `out`. CREATE(n) must return a new container with n
// slots and SET_ITEM(container, i, value) must store `value` in slot i,
// taking over the caller's reference (as PyList_SetItem / PyTuple_SetItem do).
//
// Elements [start_idx, stop_idx) of the union array are converted with
// get_value; null entries become None. On success the container is stored in
// *out. On any failure the partially built container is released and a
// non-OK Status is returned, so *out is never left pointing at it.
//
// Only block comments may be used inside the macro because of the line
// continuations.
#define DESERIALIZE_SEQUENCE(CREATE, SET_ITEM)                                   \
  auto data = std::dynamic_pointer_cast<UnionArray>(array);                      \
  if (!data) { return Status::Invalid("expected a union array"); }               \
  int32_t size = array->length();                                                \
  PyObject* result = CREATE(stop_idx - start_idx);                               \
  if (result == nullptr) {                                                       \
    return Status::UnknownError("failed to create Python sequence");             \
  }                                                                              \
  auto types = std::make_shared<Int8Array>(size, data->type_ids());              \
  auto offsets = std::make_shared<Int32Array>(size, data->value_offsets());      \
  for (int32_t i = start_idx; i < stop_idx; ++i) {                               \
    PyObject* value = nullptr;                                                   \
    if (data->IsNull(i)) {                                                       \
      Py_INCREF(Py_None);                                                        \
      value = Py_None;                                                           \
    } else {                                                                     \
      int32_t offset = offsets->Value(i);                                        \
      int8_t type = types->Value(i);                                             \
      std::shared_ptr<Array> arr = data->child(type);                            \
      Status status = get_value(arr, offset, type, base, tensors, &value);       \
      if (status.ok() && value == nullptr) {                                     \
        status = Status::UnknownError("failed to create Python object");         \
      }                                                                          \
      if (!status.ok()) {                                                        \
        /* Release the partially filled container instead of leaking it. */      \
        Py_DECREF(result);                                                       \
        return status;                                                           \
      }                                                                          \
    }                                                                            \
    SET_ITEM(result, i - start_idx, value);                                      \
  }                                                                              \
  *out = result;                                                                 \
  return Status::OK();
110+
111+
// Builds a Python list from elements [start_idx, stop_idx) of the union array
// `array` and stores it in *out. `base` and `tensors` are passed through to
// the per-element conversion. The whole body comes from DESERIALIZE_SEQUENCE.
Status DeserializeList(std::shared_ptr<Array> array,
                       int32_t start_idx,
                       int32_t stop_idx,
                       PyObject* base,
                       const std::vector<std::shared_ptr<Tensor>>& tensors,
                       PyObject** out) {
  DESERIALIZE_SEQUENCE(PyList_New, PyList_SetItem)
}
115+
116+
// Builds a Python tuple from elements [start_idx, stop_idx) of the union array
// `array` and stores it in *out. `base` and `tensors` are passed through to
// the per-element conversion. The whole body comes from DESERIALIZE_SEQUENCE.
Status DeserializeTuple(std::shared_ptr<Array> array,
                        int32_t start_idx,
                        int32_t stop_idx,
                        PyObject* base,
                        const std::vector<std::shared_ptr<Tensor>>& tensors,
                        PyObject** out) {
  DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem)
}
120+
121+
// Builds a Python dict from entries [start_idx, stop_idx) of the struct array
// `array`, whose field 0 holds the keys and field 1 the values.
//
// Keys and values are first deserialized into two temporary lists and then
// paired up. If the finished dict contains the key "_pytype_" and
// pyarrow_deserialize_callback is set, the callback is invoked with the dict
// and its return value replaces the dict as the result.
//
// On success *out holds a new reference. On failure a non-OK Status is
// returned and every intermediate Python object is released.
Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx,
    PyObject* base, const std::vector<std::shared_ptr<Tensor>>& tensors, PyObject** out) {
  auto data = std::dynamic_pointer_cast<StructArray>(array);
  if (!data) { return Status::Invalid("DeserializeDict expects a struct array"); }

  // TODO(pcm): get rid of the temporary copy of the list
  PyObject* keys = nullptr;
  PyObject* vals = nullptr;
  ARROW_RETURN_NOT_OK(
      DeserializeList(data->field(0), start_idx, stop_idx, base, tensors, &keys));
  Status status =
      DeserializeList(data->field(1), start_idx, stop_idx, base, tensors, &vals);
  if (status.ok() && (keys == nullptr || vals == nullptr)) {
    status = Status::UnknownError("failed to deserialize dict keys or values");
  }

  PyObject* result = nullptr;
  if (status.ok()) {
    result = PyDict_New();
    if (result == nullptr) {
      status = Status::UnknownError("failed to create Python dict");
    }
  }
  if (status.ok()) {
    for (int32_t i = start_idx; i < stop_idx; ++i) {
      // PyList_GetItem returns borrowed references and PyDict_SetItem takes
      // its own references, so nothing has to be released per item.
      if (PyDict_SetItem(result, PyList_GetItem(keys, i - start_idx),
              PyList_GetItem(vals, i - start_idx)) != 0) {
        // For example an unhashable key.
        status = Status::UnknownError("failed to insert item into Python dict");
        break;
      }
    }
  }
  // The temporary lists are no longer needed; the dict keeps the items alive.
  Py_XDECREF(keys);
  Py_XDECREF(vals);
  if (!status.ok()) {
    Py_XDECREF(result);
    return status;
  }

  if (pyarrow_deserialize_callback) {
    // Created once and kept for the lifetime of the process; creation is
    // retried on the next call if it failed.
    static PyObject* py_type = nullptr;
    if (py_type == nullptr) { py_type = PyUnicode_FromString("_pytype_"); }
    // PyDict_Contains returns -1 on error, which must not be read as "found".
    const int has_py_type = (py_type == nullptr) ? -1 : PyDict_Contains(result, py_type);
    if (has_py_type < 0) {
      Py_DECREF(result);
      return Status::UnknownError("failed to look up _pytype_ in Python dict");
    }
    if (has_py_type > 0) {
      // The result of the call will be passed to Python, which then owns it.
      PyObject* callback_result = PyObject_CallFunctionObjArgs(
          pyarrow_deserialize_callback, result, static_cast<PyObject*>(nullptr));
      Py_DECREF(result);
      if (!callback_result) { return Status::NotImplemented("python error"); }
      result = callback_result;
    }
  }
  *out = result;
  return Status::OK();
}
151+
152+
// Converts the tensor referenced by `array[offset]` into a read-only ndarray.
//
// `array` is an Int32Array whose value at `offset` is an index into
// `tensors`. The selected tensor is converted with py::TensorToNdarray
// (receiving `base`), and the resulting object's `flags.writeable` attribute
// is set to False.
//
// On success *out holds the ndarray. If the index is out of range, the
// conversion fails, or the array cannot be marked immutable, a non-OK Status
// is returned and *out does not hold a live object created here.
Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject* base,
    const std::vector<std::shared_ptr<arrow::Tensor>>& tensors, PyObject** out) {
  DCHECK(array);
  const int32_t index = std::static_pointer_cast<Int32Array>(array)->Value(offset);
  // The index comes from serialized data; never trust it to be in range.
  if (index < 0 || static_cast<int64_t>(index) >= static_cast<int64_t>(tensors.size())) {
    return Status::Invalid("tensor index out of bounds");
  }
  RETURN_NOT_OK(py::TensorToNdarray(*tensors[index], base, out));
  /* Mark the array as immutable. */
  PyObject* flags = PyObject_GetAttrString(*out, "flags");
  if (flags == nullptr) {
    Py_CLEAR(*out);
    return Status::UnknownError("Could not mark Numpy array immutable");
  }
  const int flag_set = PyObject_SetAttrString(flags, "writeable", Py_False);
  Py_DECREF(flags);
  if (flag_set != 0) {
    Py_CLEAR(*out);
    return Status::UnknownError("Could not mark Numpy array immutable");
  }
  return Status::OK();
}
165+
166+
} // namespace arrow
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#ifndef ARROW_PYTHON_ARROW_TO_PYTHON_H
19+
#define ARROW_PYTHON_ARROW_TO_PYTHON_H
20+
21+
#include <Python.h>
22+
23+
#include <arrow/api.h>
24+
25+
extern "C" {
26+
extern PyObject* pyarrow_serialize_callback;
27+
extern PyObject* pyarrow_deserialize_callback;
28+
}
29+
30+
namespace arrow {
31+
32+
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx,
33+
int32_t stop_idx, PyObject* base,
34+
const std::vector<std::shared_ptr<arrow::Tensor>>& tensors, PyObject** out);
35+
36+
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx,
37+
int32_t stop_idx, PyObject* base,
38+
const std::vector<std::shared_ptr<arrow::Tensor>>& tensors, PyObject** out);
39+
40+
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx,
41+
int32_t stop_idx, PyObject* base,
42+
const std::vector<std::shared_ptr<arrow::Tensor>>& tensors, PyObject** out);
43+
44+
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset,
45+
PyObject* base, const std::vector<std::shared_ptr<arrow::Tensor>>& tensors,
46+
PyObject** out);
47+
48+
} // namespace arrow
49+
50+
#endif // ARROW_PYTHON_ARROW_TO_PYTHON_H

python/pyarrow/serialization.pxi

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,24 @@
1818
from libcpp cimport bool as c_bool, nullptr
1919
from libcpp.vector cimport vector as c_vector
2020
from cpython.ref cimport PyObject, Py_DECREF
21+
from cython.operator cimport dereference as deref
2122

2223
from pyarrow.lib cimport Buffer, NativeFile, check_status
2324

24-
cdef extern from "arrow/python/python_to_arrow.h" nogil:
25+
cdef extern from "arrow/python/python_to_arrow.h":
2526

2627
cdef CStatus SerializeSequences(c_vector[PyObject*] sequences,
2728
int32_t recursion_depth, shared_ptr[CArray]* array_out,
2829
c_vector[PyObject*]& tensors_out)
2930

3031
cdef shared_ptr[CRecordBatch] MakeBatch(shared_ptr[CArray] data)
3132

33+
cdef extern from "arrow/python/arrow_to_python.h":
34+
35+
cdef CStatus DeserializeList(shared_ptr[CArray] array, int32_t start_idx,
36+
int32_t stop_idx, PyObject* base,
37+
const c_vector[shared_ptr[CTensor]]& tensors, PyObject** out)
38+
3239
cdef class PythonObject:
3340

3441
cdef:
@@ -54,3 +61,9 @@ def serialize_sequence(object value):
5461
check_status(NdarrayToTensor(c_default_memory_pool(), <object> tensor, &out))
5562
result.tensors.push_back(out)
5663
return result
64+
65+
# Main entry point for deserialization
def deserialize_sequence(PythonObject value, object base):
    """
    Deserialize column 0 of ``value.batch`` into a Python list.

    All rows of the batch are converted with DeserializeList; ``base`` and
    ``value.tensors`` are passed through to it unchanged. A failing Status is
    raised through ``check_status``.
    """
    cdef PyObject* result = NULL
    cdef object py_result
    check_status(DeserializeList(deref(value.batch).column(0), 0,
                                 deref(value.batch).num_rows(),
                                 <PyObject*> base, value.tensors, &result))
    # DeserializeList hands back a new (owned) reference, and the cast to
    # ``object`` takes one more. Drop the extra reference so the list is
    # released once the caller is done with it instead of leaking.
    py_result = <object> result
    Py_DECREF(py_result)
    return py_result

0 commit comments

Comments
 (0)