Skip to content

Commit

Permalink
some fixes, more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Dec 11, 2023
1 parent 23ba2a0 commit 4dffacd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 19 deletions.
9 changes: 8 additions & 1 deletion cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,14 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
Status::Invalid(ext_arr->value_type()->ToString(),
" is not valid data type for a tensor"));
auto permutation = ext_type->permutation();
std::vector<int64_t> permutation = ext_type->permutation();
if (permutation.empty()) {
for (int64_t i = 0; i < static_cast<int64_t>(ext_type->ndim()); i++) {
permutation.emplace_back(i);
}
ARROW_LOG(INFO) << "generated permutation: "
<< ::arrow::internal::PrintVector{permutation, ","};
}

std::vector<std::string> dim_names;
if (!ext_type->dim_names().empty()) {
Expand Down
41 changes: 41 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/tensor.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/sort.h"

namespace arrow {

Expand Down Expand Up @@ -337,6 +338,24 @@ void CheckTensorRoundtrip(const std::shared_ptr<Tensor>& tensor) {
ASSERT_TRUE(tensor->Equals(*tensor_from_array));
}

void CheckToTensor(const std::shared_ptr<DataType>& ext_type,
const std::shared_ptr<Tensor>& expected_tensor) {
const std::vector<int64_t> values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
const std::shared_ptr<DataType> cell_type = fixed_size_list(int64(), 12);
std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values)};

auto arr_data = std::make_shared<ArrayData>(int64(), values.size(), buffers);
auto arr = std::make_shared<Int64Array>(arr_data);
ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type));
auto ext_arr = ExtensionType::WrapArray(ext_type, fsla_arr);
const auto tensor_array = std::static_pointer_cast<FixedShapeTensorArray>(ext_arr);

ASSERT_OK_AND_ASSIGN(const auto actual_tensor, tensor_array->ToTensor());
ASSERT_TRUE(actual_tensor->Equals(*expected_tensor));
}

TEST_F(TestExtensionType, RoundtripTensor) {
auto values = Buffer::Wrap(values_);

Expand All @@ -356,6 +375,28 @@ TEST_F(TestExtensionType, RoundtripTensor) {
strides[i], tensor_dim_names[i]));
CheckTensorRoundtrip(tensor);
}

for (size_t i = 0; i < shapes.size(); i++) {
auto cell_shape = std::vector<int64_t>(shapes[i].begin() + 1, shapes[i].end());
auto cell_dim_names = std::vector<std::string>(tensor_dim_names[i].begin() + 1,
tensor_dim_names[i].end());
auto cell_strides = std::vector<int64_t>(strides[i].begin() + 1, strides[i].end());
auto cell_permutation = internal::ArgSort(cell_strides, std::greater<>());
// std::vector<int64_t> cell_shape;
// for (auto j : cell_permutation) {
// cell_shape.push_back(shapes[i][cell_permutation[j] + 1]);
// }

const auto ext_type =
fixed_shape_tensor(value_type_, cell_shape, cell_permutation, cell_dim_names);
auto tmp_dim_names = tensor_dim_names[i];
tmp_dim_names[0] = "";
ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values, shapes[i],
strides[i], tmp_dim_names));
CheckToTensor(ext_type, tensor);
// TODO: handle non-trivial cases
break;
}
}

TEST_F(TestExtensionType, SliceTensor) {
Expand Down
48 changes: 30 additions & 18 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,25 +1319,26 @@ def test_tensor_type():


def test_tensor_class_methods():
from numpy.lib.stride_tricks import as_strided

values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
np_arr = np.array(values, dtype=np.int8)
tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
storage = pa.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
pa.list_(pa.float32(), 6))
storage = pa.array(np_arr.reshape(2, 6).tolist(), pa.list_(pa.float32(), 6))
arr = pa.ExtensionArray.from_storage(tensor_type, storage)

# TODO: add more get_tensor tests
assert arr.get_tensor(0) == pa.Tensor.from_numpy(
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))

# expected = np.array(
# [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=np.float32)
# result = arr.to_numpy_ndarray()
# np.testing.assert_array_equal(result, expected)

# np.testing.assert_array_equal(result, expected)
expected = np_arr.reshape(2, 2, 3).astype(np.float32).tolist()
np.testing.assert_array_equal(arr.to_tensor(), expected)
np.testing.assert_array_equal(arr.to_numpy_ndarray(), expected)

# TODO: offset not correctly handled
expected = np_arr[6:].reshape(1, 2, 3).astype(np.float32)
# expected = np.array([[[7, 8, 9], [10, 11, 12]]], dtype=np.float32)
# result = arr[:1].to_numpy_ndarray()
result = arr[:1].to_numpy_ndarray()
# np.testing.assert_array_equal(result, expected)

arr = np.array(
Expand All @@ -1348,18 +1349,29 @@ def test_tensor_class_methods():
assert tensor_array_from_numpy.type.value_type == pa.float32()
assert tensor_array_from_numpy.type.shape == [2, 3]

# TODO
arr = np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.float32, order="F")
with pytest.raises(ValueError, match="C-style contiguous segment"):
arr = np_arr.reshape((-1, 3), order="F")
with pytest.raises(ValueError, match="First stride needs to be largest"):
pa.FixedShapeTensorArray.from_numpy_ndarray(arr)

# expected = np.array(values, dtype=np.int8).reshape(1, 2, 3, 2)
# tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1])
bw = np.int8().itemsize
arr = as_strided(np_arr, shape=(3, 4), strides=(bw * 4, bw), writeable=False)
tensor_array_from_numpy = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
assert tensor_array_from_numpy.type.shape == [4]
assert tensor_array_from_numpy.type.permutation == [1]
# TODO: strides not correctly handled
# assert tensor_array_from_numpy.to_tensor() == pa.Tensor.from_numpy(arr)
# assert tensor_array_from_numpy.strides == [bw * 4, bw]

arr = as_strided(np_arr, shape=(1, 2, 3, 2), strides=(
bw * 12, bw, bw * 2, bw * 6), writeable=False)
expected = np.array(values, dtype=np.int8).reshape(1, 2, 3, 2)
# TODO: strides not correctly handled
# np.testing.assert_array_equal(arr, expected)

tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1])
storage = pa.array([values], pa.list_(pa.int8(), 12))
# result = pa.ExtensionArray.from_storage(tensor_type, storage)
# TODO
result = pa.ExtensionArray.from_storage(tensor_type, storage)
# TODO: strides not correctly handled
# assert np.testing.assert_array_equal(result.to_numpy_ndarray(), expected)

expected = np.array(values, dtype=np.int8).reshape(1, 2, 2, 3)
Expand Down

0 comments on commit 4dffacd

Please sign in to comment.