Skip to content

Commit 5766b8c

Browse files
committed
python to arrow serialization
1 parent 94b7cfa commit 5766b8c

File tree

10 files changed

+840
-0
lines changed

10 files changed

+840
-0
lines changed

cpp/src/arrow/python/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@ set(ARROW_PYTHON_SRCS
4646
builtin_convert.cc
4747
common.cc
4848
config.cc
49+
dict.cc
4950
helpers.cc
5051
init.cc
5152
io.cc
5253
numpy_convert.cc
5354
pandas_to_arrow.cc
55+
python_to_arrow.cc
5456
pyarrow.cc
57+
sequence
5558
)
5659

5760
set(ARROW_PYTHON_SHARED_LINK_LIBS
@@ -86,14 +89,17 @@ install(FILES
8689
builtin_convert.h
8790
common.h
8891
config.h
92+
dict.h
8993
helpers.h
9094
init.h
9195
io.h
9296
numpy_convert.h
9397
numpy_interop.h
9498
pandas_to_arrow.h
99+
python_to_arrow.h
95100
platform.h
96101
pyarrow.h
102+
sequence.h
97103
type_traits.h
98104
DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/python")
99105

cpp/src/arrow/python/dict.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "dict.h"
19+
20+
namespace arrow {
21+
22+
/// Assemble the accumulated keys and values into a single StructArray.
///
/// The keys and values sequences are finished independently and then combined
/// into a struct with two fields, "keys" and "vals".
///
/// \param key_tuple_data forwarded to the keys builder as its tuple data
/// \param key_dict_data forwarded to the keys builder as its dict data
/// \param val_list_data forwarded to the values builder as its list data
/// \param val_tuple_data forwarded to the values builder as its tuple data
/// \param val_dict_data forwarded to the values builder as its dict data
/// \param out receives the resulting StructArray; must not be null
/// \return the first non-OK status of either sub-builder, otherwise OK
Status DictBuilder::Finish(std::shared_ptr<Array> key_tuple_data,
    std::shared_ptr<Array> key_dict_data, std::shared_ptr<Array> val_list_data,
    std::shared_ptr<Array> val_tuple_data, std::shared_ptr<Array> val_dict_data,
    std::shared_ptr<arrow::Array>* out) {
  DCHECK(out);
  // lists and dicts can't be keys of dicts in Python, that is why for
  // the keys we do not need to collect sublists
  std::shared_ptr<Array> keys, vals;
  RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, key_dict_data, &keys));
  RETURN_NOT_OK(vals_.Finish(val_list_data, val_tuple_data, val_dict_data, &vals));
  // Every key must be paired with exactly one value.
  DCHECK(keys->length() == vals->length());
  auto keys_field = std::make_shared<Field>("keys", keys->type());
  auto vals_field = std::make_shared<Field>("vals", vals->type());
  auto type =
      std::make_shared<StructType>(std::vector<FieldPtr>({keys_field, vals_field}));
  std::vector<std::shared_ptr<Array>> field_arrays({keys, vals});
  // make_shared instead of a naked `new`: one allocation, no raw owning pointer.
  *out = std::make_shared<StructArray>(type, keys->length(), field_arrays);
  return Status::OK();
}
40+
41+
} // namespace arrow

cpp/src/arrow/python/dict.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#ifndef PYTHON_ARROW_DICT_H
19+
#define PYTHON_ARROW_DICT_H
20+
21+
#include <arrow/api.h>
22+
23+
#include "sequence.h"
24+
25+
namespace arrow {
26+
27+
/// Constructing dictionaries of key/value pairs. Sequences of
28+
/// keys and values are built separately using a pair of
29+
/// SequenceBuilders. The resulting Arrow representation
30+
/// can be obtained via the Finish method.
31+
class DictBuilder {
32+
public:
33+
DictBuilder(arrow::MemoryPool* pool = nullptr) : keys_(pool), vals_(pool) {}
34+
35+
/// Builder for the keys of the dictionary
36+
SequenceBuilder& keys() { return keys_; }
37+
/// Builder for the values of the dictionary
38+
SequenceBuilder& vals() { return vals_; }
39+
40+
/// Construct an Arrow StructArray representing the dictionary.
41+
/// Contains a field "keys" for the keys and "vals" for the values.
42+
43+
/// \param list_data
44+
/// List containing the data from nested lists in the value
45+
/// list of the dictionary
46+
///
47+
/// \param dict_data
48+
/// List containing the data from nested dictionaries in the
49+
/// value list of the dictionary
50+
arrow::Status Finish(std::shared_ptr<arrow::Array> key_tuple_data,
51+
std::shared_ptr<arrow::Array> key_dict_data,
52+
std::shared_ptr<arrow::Array> val_list_data,
53+
std::shared_ptr<arrow::Array> val_tuple_data,
54+
std::shared_ptr<arrow::Array> val_dict_data, std::shared_ptr<arrow::Array>* out);
55+
56+
private:
57+
SequenceBuilder keys_;
58+
SequenceBuilder vals_;
59+
};
60+
61+
} // namespace arrow
62+
63+
#endif // PYTHON_ARROW_DICT_H
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "python_to_arrow.h"
19+
20+
#include <sstream>
21+
22+
#include "scalars.h"
23+
24+
constexpr int32_t kMaxRecursionDepth = 100;
25+
26+
extern "C" {
27+
PyObject* pyarrow_serialize_callback = NULL;
28+
PyObject* pyarrow_deserialize_callback = NULL;
29+
}
30+
31+
namespace arrow {
32+
33+
/// Append a single Python object to a sequence builder.
///
/// Scalars (bool, float, int, bytes, unicode, None, numpy scalars) are written
/// directly into the builder. Containers are only registered in the builder by
/// size; the container itself is pushed onto the matching output vector so the
/// caller can serialize its contents afterwards.
///
/// \param elem the object to append (borrowed reference)
/// \param builder the builder receiving the element
/// \param sublists collects list objects encountered
/// \param subtuples collects tuple objects encountered
/// \param subdicts collects dict objects encountered, including the results
///        returned by the custom serialization callback (new references that
///        are released in SerializeDict)
/// \param tensors_out collects numpy arrays handled by SerializeArray
/// \return NotImplemented if the object cannot be serialized or a Python call
///         fails, otherwise the status of the underlying builder call
Status append(PyObject* elem, SequenceBuilder& builder, std::vector<PyObject*>& sublists,
    std::vector<PyObject*>& subtuples, std::vector<PyObject*>& subdicts,
    std::vector<PyObject*>& tensors_out) {
  // The bool case must precede the int case (PyInt_Check passes for bools)
  if (PyBool_Check(elem)) {
    RETURN_NOT_OK(builder.AppendBool(elem == Py_True));
  } else if (PyFloat_Check(elem)) {
    RETURN_NOT_OK(builder.AppendDouble(PyFloat_AS_DOUBLE(elem)));
  } else if (PyLong_Check(elem)) {
    int overflow = 0;
    int64_t data = PyLong_AsLongLongAndOverflow(elem, &overflow);
    // Check for overflow before touching the builder so that no bogus value
    // is appended for an integer that does not fit into 64 bits.
    if (overflow) { return Status::NotImplemented("long overflow"); }
    RETURN_NOT_OK(builder.AppendInt64(data));
#if PY_MAJOR_VERSION < 3
  } else if (PyInt_Check(elem)) {
    RETURN_NOT_OK(builder.AppendInt64(static_cast<int64_t>(PyInt_AS_LONG(elem))));
#endif
  } else if (PyBytes_Check(elem)) {
    auto data = reinterpret_cast<uint8_t*>(PyBytes_AS_STRING(elem));
    auto size = PyBytes_GET_SIZE(elem);
    RETURN_NOT_OK(builder.AppendBytes(data, size));
  } else if (PyUnicode_Check(elem)) {
    Py_ssize_t size;
#if PY_MAJOR_VERSION >= 3
    // `auto` because the constness of the returned pointer differs between
    // Python 3 releases.
    auto data = PyUnicode_AsUTF8AndSize(elem, &size);
    if (data == nullptr) { return Status::NotImplemented("python error"); }
    Status s = builder.AppendString(data, size);
#else
    PyObject* str = PyUnicode_AsUTF8String(elem);
    if (str == nullptr) { return Status::NotImplemented("python error"); }
    char* data = PyString_AS_STRING(str);
    size = PyString_GET_SIZE(str);
    Status s = builder.AppendString(data, size);
    Py_DECREF(str);
#endif
    RETURN_NOT_OK(s);
  } else if (PyList_Check(elem)) {
    // NOTE(review): the results of AppendList/AppendTuple/AppendDict are not
    // checked; if they return a Status they should be wrapped in
    // RETURN_NOT_OK -- confirm against sequence.h.
    builder.AppendList(PyList_Size(elem));
    sublists.push_back(elem);
  } else if (PyDict_Check(elem)) {
    builder.AppendDict(PyDict_Size(elem));
    subdicts.push_back(elem);
  } else if (PyTuple_CheckExact(elem)) {
    builder.AppendTuple(PyTuple_Size(elem));
    subtuples.push_back(elem);
  } else if (PyArray_IsScalar(elem, Generic)) {
    RETURN_NOT_OK(AppendScalar(elem, builder));
  } else if (PyArray_Check(elem)) {
    RETURN_NOT_OK(SerializeArray(
        reinterpret_cast<PyArrayObject*>(elem), builder, subdicts, tensors_out));
  } else if (elem == Py_None) {
    RETURN_NOT_OK(builder.AppendNone());
  } else {
    if (!pyarrow_serialize_callback) {
      // PyObject_Repr returns a new reference to a text object: unicode on
      // Python 3, str on Python 2. Convert accordingly and release it.
      PyObject* repr = PyObject_Repr(elem);
      const char* repr_str = nullptr;
      if (repr != nullptr) {
#if PY_MAJOR_VERSION >= 3
        repr_str = PyUnicode_AsUTF8(repr);
#else
        repr_str = PyString_AS_STRING(repr);
#endif
      }
      // The failure is reported through the Status, so do not leave a
      // secondary Python exception pending.
      if (repr_str == nullptr) { PyErr_Clear(); }
      std::stringstream ss;
      ss << "data type of " << (repr_str ? repr_str : "<unprintable object>")
         << " not recognized and custom serialization handler not registered";
      Py_XDECREF(repr);
      return Status::NotImplemented(ss.str());
    } else {
      PyObject* arglist = Py_BuildValue("(O)", elem);
      if (!arglist) { return Status::NotImplemented("python error"); }
      // The reference count of the result of the call to PyObject_CallObject
      // must be decremented. This is done in SerializeDict in this file.
      PyObject* result = PyObject_CallObject(pyarrow_serialize_callback, arglist);
      Py_DECREF(arglist);
      if (!result) { return Status::NotImplemented("python error"); }
      // The result is treated as a dict below; reject anything else instead
      // of recording a size of -1.
      if (!PyDict_Check(result)) {
        Py_DECREF(result);
        return Status::NotImplemented(
            "custom serialization handler did not return a dict");
      }
      builder.AppendDict(PyDict_Size(result));
      subdicts.push_back(result);
    }
  }
  return Status::OK();
}
101+
102+
/// Serialize a numpy array into a sequence builder.
///
/// Arrays with a boolean, integer or floating point dtype are recorded as a
/// tensor reference: the builder stores the index of the array within
/// tensors_out and the array itself is pushed onto tensors_out. Arrays of any
/// other dtype are passed to the custom serialization callback, whose result is
/// recorded as a dict.
///
/// \param array the numpy array to serialize (borrowed reference)
/// \param builder the builder receiving the element
/// \param subdicts collects the dict returned by the serialization callback
///        (a new reference that is released in SerializeDict)
/// \param tensors_out collects arrays that are serialized as tensors
/// \return NotImplemented if the dtype is unsupported and no callback is
///         registered, or if the callback fails
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder,
    std::vector<PyObject*>& subdicts, std::vector<PyObject*>& tensors_out) {
  int dtype = PyArray_TYPE(array);
  switch (dtype) {
    case NPY_BOOL:
    case NPY_UINT8:
    case NPY_INT8:
    case NPY_UINT16:
    case NPY_INT16:
    case NPY_UINT32:
    case NPY_INT32:
    case NPY_UINT64:
    case NPY_INT64:
    case NPY_FLOAT:
    case NPY_DOUBLE: {
      // The tensor is referenced by its position in tensors_out.
      RETURN_NOT_OK(builder.AppendTensor(tensors_out.size()));
      tensors_out.push_back(reinterpret_cast<PyObject*>(array));
    } break;
    default:
      if (!pyarrow_serialize_callback) {
        std::stringstream stream;
        stream << "numpy data type not recognized: " << dtype;
        return Status::NotImplemented(stream.str());
      } else {
        PyObject* arglist = Py_BuildValue("(O)", array);
        if (!arglist) { return Status::NotImplemented("python error"); }
        // The reference count of the result of the call to PyObject_CallObject
        // must be decremented. This is done in SerializeDict in this file.
        PyObject* result = PyObject_CallObject(pyarrow_serialize_callback, arglist);
        Py_DECREF(arglist);
        if (!result) { return Status::NotImplemented("python error"); }
        // The result is treated as a dict below; reject anything else instead
        // of recording a size of -1.
        if (!PyDict_Check(result)) {
          Py_DECREF(result);
          return Status::NotImplemented(
              "custom serialization handler did not return a dict");
        }
        builder.AppendDict(PyDict_Size(result));
        subdicts.push_back(result);
      }
  }
  return Status::OK();
}
138+
139+
/// Serialize a collection of Python iterables into one Arrow array.
///
/// All elements of all given sequences are appended to a single builder.
/// Nested lists, tuples and dicts found on the way are collected and
/// serialized recursively, one level deeper, before the builder is finished.
///
/// \param sequences the iterables to serialize (borrowed references)
/// \param recursion_depth current nesting level; serialization is aborted once
///        it reaches kMaxRecursionDepth
/// \param out receives the resulting array; must not be null
/// \param tensors_out collects numpy arrays that are serialized as tensors
/// \return NotImplemented if the nesting is too deep, an element cannot be
///         serialized or iterating fails
Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth,
    std::shared_ptr<Array>* out, std::vector<PyObject*>& tensors_out) {
  DCHECK(out);
  if (recursion_depth >= kMaxRecursionDepth) {
    return Status::NotImplemented(
        "This object exceeds the maximum recursion depth. It may contain itself "
        "recursively.");
  }
  SequenceBuilder builder(nullptr);
  std::vector<PyObject *> sublists, subtuples, subdicts;
  for (const auto& sequence : sequences) {
    PyObject* iterator = PyObject_GetIter(sequence);
    // PyObject_GetIter returns NULL for non-iterable objects; PyIter_Next must
    // not be called on a NULL iterator.
    if (iterator == nullptr) { return Status::NotImplemented("python error"); }
    PyObject* item;
    while ((item = PyIter_Next(iterator))) {
      Status s = append(item, builder, sublists, subtuples, subdicts, tensors_out);
      Py_DECREF(item);
      // if an error occurs, we need to decrement the reference counts before returning
      if (!s.ok()) {
        Py_DECREF(iterator);
        return s;
      }
    }
    Py_DECREF(iterator);
    // PyIter_Next returns NULL both on exhaustion and on error; only the error
    // case leaves an exception set.
    if (PyErr_Occurred()) { return Status::NotImplemented("python error"); }
  }
  std::shared_ptr<Array> list;
  if (sublists.size() > 0) {
    RETURN_NOT_OK(SerializeSequences(sublists, recursion_depth + 1, &list, tensors_out));
  }
  std::shared_ptr<Array> tuple;
  if (subtuples.size() > 0) {
    RETURN_NOT_OK(
        SerializeSequences(subtuples, recursion_depth + 1, &tuple, tensors_out));
  }
  std::shared_ptr<Array> dict;
  if (subdicts.size() > 0) {
    RETURN_NOT_OK(SerializeDict(subdicts, recursion_depth + 1, &dict, tensors_out));
  }
  return builder.Finish(list, tuple, dict, out);
}
178+
179+
/// Serialize a collection of Python dicts into one Arrow array.
///
/// Keys and values of all given dicts are appended to a DictBuilder. Nested
/// containers found among the keys and values are serialized recursively, one
/// level deeper, before the builder is finished.
///
/// Dicts containing the key "_pytype_" are treated as results of the custom
/// serialization callback; the reference to them is released here once they
/// have been serialized.
///
/// \param dicts the dicts to serialize
/// \param recursion_depth current nesting level; serialization is aborted once
///        it reaches kMaxRecursionDepth
/// \param out receives the resulting array
/// \param tensors_out collects numpy arrays that are serialized as tensors
/// \return NotImplemented if the nesting is too deep or an element cannot be
///         serialized, otherwise the status of finishing the builder
Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
    std::shared_ptr<Array>* out, std::vector<PyObject*>& tensors_out) {
  if (recursion_depth >= kMaxRecursionDepth) {
    return Status::NotImplemented(
        "This object exceeds the maximum recursion depth. It may contain itself "
        "recursively.");
  }
  DictBuilder result;
  // `dummy` stands in for the list collector of the keys: lists cannot be
  // dict keys, so it is expected to stay empty.
  std::vector<PyObject *> key_tuples, key_dicts, val_lists, val_tuples, val_dicts, dummy;
  for (const auto& dict : dicts) {
    PyObject *key, *value;
    Py_ssize_t pos = 0;
    while (PyDict_Next(dict, &pos, &key, &value)) {
      RETURN_NOT_OK(
          append(key, result.keys(), dummy, key_tuples, key_dicts, tensors_out));
      DCHECK(dummy.size() == 0);
      RETURN_NOT_OK(
          append(value, result.vals(), val_lists, val_tuples, val_dicts, tensors_out));
    }
  }
  std::shared_ptr<Array> key_tuples_arr;
  if (key_tuples.size() > 0) {
    RETURN_NOT_OK(SerializeSequences(
        key_tuples, recursion_depth + 1, &key_tuples_arr, tensors_out));
  }
  std::shared_ptr<Array> key_dicts_arr;
  if (key_dicts.size() > 0) {
    RETURN_NOT_OK(
        SerializeDict(key_dicts, recursion_depth + 1, &key_dicts_arr, tensors_out));
  }
  std::shared_ptr<Array> val_list_arr;
  if (val_lists.size() > 0) {
    RETURN_NOT_OK(
        SerializeSequences(val_lists, recursion_depth + 1, &val_list_arr, tensors_out));
  }
  std::shared_ptr<Array> val_tuples_arr;
  if (val_tuples.size() > 0) {
    RETURN_NOT_OK(SerializeSequences(
        val_tuples, recursion_depth + 1, &val_tuples_arr, tensors_out));
  }
  std::shared_ptr<Array> val_dict_arr;
  if (val_dicts.size() > 0) {
    RETURN_NOT_OK(
        SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr, tensors_out));
  }
  // Keep the status instead of dropping it; it is returned once the cleanup
  // below has run.
  Status finish_status = result.Finish(
      key_tuples_arr, key_dicts_arr, val_list_arr, val_tuples_arr, val_dict_arr, out);

  // This block is used to decrement the reference counts of the results
  // returned by the serialization callback, which is called in append and in
  // SerializeArray in this file.
  // NOTE(review): the RETURN_NOT_OK paths above return before this cleanup
  // runs, so callback results are not released when serialization fails.
  static PyObject* py_type = PyUnicode_FromString("_pytype_");
  if (py_type == nullptr) { return Status::NotImplemented("python error"); }
  for (const auto& dict : dicts) {
    // PyDict_Contains returns -1 on error; only a definite hit (1) identifies
    // a callback result.
    if (PyDict_Contains(dict, py_type) == 1) {
      // If the dictionary contains the key "_pytype_", then the user has to
      // have registered a callback.
      ARROW_CHECK(pyarrow_serialize_callback);
      Py_XDECREF(dict);
    }
  }

  return finish_status;
}
242+
243+
std::shared_ptr<RecordBatch> MakeBatch(std::shared_ptr<Array> data) {
244+
auto field = std::make_shared<Field>("list", data->type());
245+
std::shared_ptr<Schema> schema(new Schema({field}));
246+
return std::shared_ptr<RecordBatch>(new RecordBatch(schema, data->length(), {data}));
247+
}
248+
249+
} // namespace arrow

0 commit comments

Comments
 (0)