Skip to content

Commit

Permalink
change slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Dec 13, 2023
1 parent 5085b09 commit 9c60b85
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
26 changes: 16 additions & 10 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,16 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
// To convert an array of n-dimensional tensors to an (n+1)-dimensional tensor we
// interpret the array's length as the first dimension of the new tensor.

const auto ext_arr =
const auto storage_array =
internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
ARROW_ASSIGN_OR_RAISE(const auto flattened_storage_array, storage_array->Flatten());

const auto ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
Status::Invalid(ext_arr->value_type()->ToString(),
" is not valid data type for a tensor"));
const auto value_type = ext_type->value_type();
ARROW_RETURN_IF(
!is_fixed_width(*value_type),
Status::Invalid(value_type->ToString(), " is not valid data type for a tensor"));

std::vector<int64_t> permutation = ext_type->permutation();
if (permutation.empty()) {
Expand All @@ -361,18 +364,21 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
}

std::vector<int64_t> shape = ext_type->shape();
auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto value_type =
internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
const auto fw_value_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
ARROW_RETURN_NOT_OK(
ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
ARROW_ASSIGN_OR_RAISE(const auto flattened_array, ext_arr->Flatten());
ComputeStrides(*fw_value_type.get(), shape, permutation, &tensor_strides));

return Tensor::Make(ext_arr->value_type(), flattened_array->data()->buffers[1], shape,
tensor_strides, dim_names);
ARROW_ASSIGN_OR_RAISE(
const auto buffer,
SliceBufferSafe(flattened_storage_array->data()->buffers[1],
this->offset() * cell_size * value_type->byte_width()));
return Tensor::Make(value_type, buffer, shape, tensor_strides, dim_names);
}

Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
Expand Down
6 changes: 2 additions & 4 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,8 +1334,7 @@ def test_tensor_class_methods(value_type):

expected = np.array([[[7, 8, 9], [10, 11, 12]]], dtype=value_type)
result = arr[1:].to_numpy_ndarray()
# TODO: offset of sliced pa.array is not correctly handled
# np.testing.assert_array_equal(result, expected)
np.testing.assert_array_equal(result, expected)

values = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]
flat_arr = np.array(values[0], dtype=value_type)
Expand Down Expand Up @@ -1413,8 +1412,7 @@ def test_tensor_array_from_numpy(value_type):
arr = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], dtype=value_type)
expected = arr[1:]
result = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)[1:].to_numpy_ndarray()
# TODO: offset of sliced pa.array is not correctly handled
# np.testing.assert_array_equal(result, expected)
np.testing.assert_array_equal(result, expected)


@pytest.mark.parametrize("tensor_type", (
Expand Down

0 comments on commit 9c60b85

Please sign in to comment.