Skip to content

Commit

Permalink
Python wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Aug 25, 2023
1 parent 94d948b commit 95372a0
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/python/api/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ may expose data type-specific methods or properties.
UnionArray
ExtensionArray
FixedShapeTensorArray
VariableShapeTensorArray

.. _api.scalar:

Expand Down
5 changes: 4 additions & 1 deletion python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def print_entry(label, value):
dictionary,
run_end_encoded,
fixed_shape_tensor,
variable_shape_tensor,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand All @@ -180,7 +181,8 @@ def print_entry(label, value):
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, FixedShapeTensorType,
PyExtensionType, UnknownExtensionType,
VariableShapeTensorType, PyExtensionType,
UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
KeyValueMetadata,
Expand Down Expand Up @@ -212,6 +214,7 @@ def print_entry(label, value):
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray, FixedShapeTensorArray,
VariableShapeTensorArray,
scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
Expand Down
48 changes: 48 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3264,6 +3264,54 @@ class FixedShapeTensorArray(ExtensionArray):
)


class VariableShapeTensorArray(ExtensionArray):
"""
Concrete class for variable shape tensor extension arrays.
Examples
--------
Define the extension type for tensor array
>>> import pyarrow as pa
>>> tensor_type = pa.variable_shape_tensor(pa.int32(), 2)
Create an extension array
>>> shapes = pa.array([[2, 3], [1, 2]], pa.list_(pa.uint32(), 2))
>>> values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], pa.list_(pa.int32()))
>>> arr = pa.StructArray.from_arrays([shapes, values], names=["shape", "data"])
>>> pa.ExtensionArray.from_storage(tensor_type, arr)
<pyarrow.lib.VariableShapeTensorArray object at ...>
-- is_valid: all not null
-- child 0 type: fixed_size_list<item: uint32>[2]
[
[
2,
3
],
[
1,
2
]
]
-- child 1 type: list<item: int32>
[
[
1,
2,
3,
4,
5,
6
],
[
7,
8
]
]
"""


cdef dict _array_classes = {
_Type_NA: NullArray,
_Type_BOOL: BooleanArray,
Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,27 @@ cdef extern from "arrow/extension_type.h" namespace "arrow":
shared_ptr[CArray] storage()


cdef extern from "arrow/extension/variable_shape_tensor.h" namespace "arrow::extension":
cdef cppclass CVariableShapeTensorType \
" arrow::extension::VariableShapeTensorType"(CExtensionType):

@staticmethod
CResult[shared_ptr[CDataType]] Make(const shared_ptr[CDataType]& value_type,
const uint32_t ndim,
const vector[int64_t]& permutation,
const vector[c_string]& dim_names)

CResult[shared_ptr[CDataType]] Deserialize(const shared_ptr[CDataType] storage_type,
const c_string& serialized_data) const

c_string Serialize() const

const shared_ptr[CDataType] value_type()
const uint32_t ndim()
const vector[int64_t] permutation()
const vector[c_string] dim_names()


cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension":
cdef cppclass CFixedShapeTensorType \
" arrow::extension::FixedShapeTensorType"(CExtensionType):
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ cdef class ExtensionType(BaseExtensionType):
const CPyExtensionType* cpy_ext_type


cdef class VariableShapeTensorType(BaseExtensionType):
cdef:
const CVariableShapeTensorType* tensor_ext_type


cdef class FixedShapeTensorType(BaseExtensionType):
cdef:
const CFixedShapeTensorType* tensor_ext_type
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ cdef api object pyarrow_wrap_data_type(
cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type)
if cpy_ext_type != nullptr:
return cpy_ext_type.GetInstance()
elif ext_type.extension_name() == b"arrow.variable_shape_tensor":
out = VariableShapeTensorType.__new__(VariableShapeTensorType)
elif ext_type.extension_name() == b"arrow.fixed_shape_tensor":
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
else:
Expand Down
117 changes: 117 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,39 @@ def test_tensor_type():
assert tensor_type.dim_names == ['C', 'H', 'W']
assert tensor_type.permutation is None

tensor_type = pa.variable_shape_tensor(pa.int8(), 2)
expected_storage_type = pa.struct([
pa.field("shape", pa.list_(pa.uint32(), 2)),
pa.field("data", pa.list_(pa.int8()))
])
assert tensor_type.extension_name == "arrow.variable_shape_tensor"
assert tensor_type.storage_type == expected_storage_type
assert tensor_type.ndim == 2
assert tensor_type.dim_names is None
assert tensor_type.permutation is None

tensor_type = pa.variable_shape_tensor(pa.int64(), 3, dim_names=['C', 'H', 'W'])
expected_storage_type = pa.struct([
pa.field("shape", pa.list_(pa.uint32(), 3)),
pa.field("data", pa.list_(pa.int64()))
])
assert tensor_type.extension_name == "arrow.variable_shape_tensor"
assert tensor_type.storage_type == expected_storage_type
assert tensor_type.ndim == 3
assert tensor_type.dim_names == ['C', 'H', 'W']
assert tensor_type.permutation is None

tensor_type = pa.variable_shape_tensor(pa.bool_(), 2, permutation=[1, 0])
expected_storage_type = pa.struct([
pa.field("shape", pa.list_(pa.uint32(), 2)),
pa.field("data", pa.list_(pa.bool_()))
])
assert tensor_type.extension_name == "arrow.variable_shape_tensor"
assert tensor_type.storage_type == expected_storage_type
assert tensor_type.ndim == 2
assert tensor_type.dim_names is None
assert tensor_type.permutation == [1, 0]


def test_tensor_class_methods():
tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
Expand Down Expand Up @@ -1249,6 +1282,28 @@ def test_tensor_class_methods():
arr.to_numpy_ndarray()


def test_variable_size_tensor_class_method():
shape_type = pa.list_(pa.uint32(), 2)
value_type = pa.int8()
tensor_type = pa.variable_shape_tensor(value_type, 2)
fields = [pa.field("shape", shape_type), pa.field("data", pa.list_(value_type))]
shapes = pa.array([[2, 3], [1, 2]], shape_type)
values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], pa.list_(value_type))
storage = pa.array(
[([2, 3], [1, 2, 3, 4, 5, 6]), ([1, 2], [7, 8])], type=pa.struct(fields)
)

struct_arr = pa.StructArray.from_arrays([shapes, values], fields=fields)
arr = pa.ExtensionArray.from_storage(tensor_type, struct_arr)

assert arr.to_pylist() == [
{"data": [1, 2, 3, 4, 5, 6], "shape": [2, 3]},
{"data": [7, 8], "shape": [1, 2]},
]
assert pa.ExtensionArray.from_storage(tensor_type, storage).equals(arr)
assert arr.type == tensor_type


@pytest.mark.parametrize("tensor_type", (
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1]),
Expand Down Expand Up @@ -1279,6 +1334,46 @@ def test_tensor_type_ipc(tensor_type):
assert result.type.shape == [2, 2, 3]


@pytest.mark.parametrize("tensor_type", (
pa.variable_shape_tensor(pa.int8(), 2),
pa.variable_shape_tensor(pa.int8(), 2, permutation=[1, 0]),
pa.variable_shape_tensor(pa.int8(), 2, dim_names=['H', 'W'])
))
def test_variable_size_tensor_type_ipc(tensor_type):
shape_type = tensor_type.storage_type.field(0).type
values_type = tensor_type.storage_type.field(1).type
shapes = pa.array([[2, 3], [1, 2]], shape_type)
values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], values_type)

struct_arr = pa.StructArray.from_arrays([shapes, values], names=["shape", "data"])
arr = pa.ExtensionArray.from_storage(tensor_type, struct_arr)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])

# check the built array has exactly the expected clss
tensor_class = tensor_type.__arrow_ext_class__()
assert isinstance(arr, tensor_class)

buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)

result = batch.column(0)
# check the deserialized array class is the expected one
assert isinstance(result, tensor_class)
assert result.type.extension_name == "arrow.variable_shape_tensor"
assert arr.storage.to_pylist() == [
{"data": [1, 2, 3, 4, 5, 6], "shape": [2, 3]},
{"data": [7, 8], "shape": [1, 2]},
]

# we get back an actual TensorType
assert isinstance(result.type, pa.VariableShapeTensorType)
assert result.type.value_type == pa.int8()
assert result.type.ndim == 2
assert result.type.permutation == tensor_type.permutation
assert result.type.dim_names == tensor_type.dim_names


def test_tensor_type_equality():
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
Expand All @@ -1288,6 +1383,14 @@ def test_tensor_type_equality():
assert tensor_type == tensor_type2
assert not tensor_type == tensor_type3

tensor_type = pa.variable_shape_tensor(pa.int8(), 2)
assert tensor_type.extension_name == "arrow.variable_shape_tensor"

tensor_type2 = pa.variable_shape_tensor(pa.int8(), 2)
tensor_type3 = pa.variable_shape_tensor(pa.uint8(), 2)
assert tensor_type == tensor_type2
assert not tensor_type == tensor_type3


@pytest.mark.pandas
def test_extension_to_pandas_storage_type(registered_period_type):
Expand Down Expand Up @@ -1352,3 +1455,17 @@ def test_tensor_type_is_picklable():
result = pickle.loads(pickle.dumps(expected_arr))

assert result == expected_arr

expected_type = pa.variable_shape_tensor(pa.int32(), 2)
result = pickle.loads(pickle.dumps(expected_type))

assert result == expected_type

shapes = pa.array([[2, 3], [1, 2]], pa.list_(pa.uint32(), 2))
values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], pa.list_(pa.int32()))
arr = pa.StructArray.from_arrays([shapes, values], names=["shape", "data"])
expected_arr = pa.ExtensionArray.from_storage(expected_type, arr)

result = pickle.loads(pickle.dumps(expected_arr))

assert result == expected_arr
Loading

0 comments on commit 95372a0

Please sign in to comment.