From 8e6e99093f493bbe517cdec102d2dca01d2fcac5 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Fri, 29 Mar 2024 23:56:03 +0100 Subject: [PATCH] refactor ComputeStrides --- .../extension/tensor_extension_array_test.cc | 8 ++-- cpp/src/arrow/extension/tensor_internal.cc | 39 ++++--------------- 2 files changed, 11 insertions(+), 36 deletions(-) diff --git a/cpp/src/arrow/extension/tensor_extension_array_test.cc b/cpp/src/arrow/extension/tensor_extension_array_test.cc index 83c3760888713..530ae27c85a14 100644 --- a/cpp/src/arrow/extension/tensor_extension_array_test.cc +++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc @@ -514,16 +514,16 @@ TEST_F(TestFixedShapeTensorType, ComputeStrides) { auto ext_type_5 = internal::checked_pointer_cast( fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2})); - ASSERT_EQ(ext_type_5->strides(), (std::vector{56, 224, 8})); + ASSERT_EQ(ext_type_5->strides(), (std::vector{56, 168, 8})); ASSERT_EQ(ext_type_5->Serialize(), R"({"shape":[3,4,7],"permutation":[1,0,2]})"); auto ext_type_6 = internal::checked_pointer_cast( fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {})); - ASSERT_EQ(ext_type_6->strides(), (std::vector{56, 8, 224})); + ASSERT_EQ(ext_type_6->strides(), (std::vector{32, 8, 96})); ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})"); auto ext_type_7 = internal::checked_pointer_cast( fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {})); - ASSERT_EQ(ext_type_7->strides(), (std::vector{4, 112, 16})); + ASSERT_EQ(ext_type_7->strides(), (std::vector{4, 84, 12})); ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})"); } @@ -594,7 +594,7 @@ TEST_F(TestFixedShapeTensorType, GetTensor) { // Get tensor from extension array with non-trivial permutation ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor, Tensor::Make(value_type_, Buffer::Wrap(element_values[i]), - {4, 3}, {8, 24}, {"y", "x"})); + {4, 3}, {8, 32}, {"y", "x"})); ASSERT_OK_AND_ASSIGN(scalar, permuted_array->GetScalar(i)); ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor, exact_permuted_ext_type->MakeTensor( diff --git a/cpp/src/arrow/extension/tensor_internal.cc b/cpp/src/arrow/extension/tensor_internal.cc index d021553702efb..b3d3f626441ae 100644 --- a/cpp/src/arrow/extension/tensor_internal.cc +++ b/cpp/src/arrow/extension/tensor_internal.cc @@ -47,43 +47,18 @@ Status IsPermutationValid(const std::vector& permutation) { Result> ComputeStrides(const std::shared_ptr& value_type, const std::vector& shape, const std::vector& permutation) { - const auto fixed_width_type = - internal::checked_pointer_cast(value_type); - + const auto& fw_type = checked_cast(*value_type); std::vector strides; - if (permutation.empty()) { - ARROW_DCHECK_OK( - internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, &strides)); - return strides; - } - const int byte_width = value_type->byte_width(); - - int64_t remaining = 0; - if (!shape.empty() && shape.front() > 0) { - remaining = byte_width; - for (auto i : permutation) { - if (i > 0) { - if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) { - return Status::Invalid( - "Strides computed from shape would not fit in 64-bit integer"); - } - } - } - } - if (remaining == 0) { - strides.assign(shape.size(), byte_width); + if (permutation.empty()) { + ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, shape, &strides)); return strides; } - strides.push_back(remaining); - for (auto i : permutation) { - if (i > 0) { - remaining /= shape[i]; - strides.push_back(remaining); - } - } - DCHECK_EQ(strides.back(), byte_width); + auto permuted_shape = std::move(shape); + auto reverse_permutation = internal::ArgSort(permutation, std::less<>()); + Permute(reverse_permutation, &permuted_shape); + ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, permuted_shape, &strides)); Permute(permutation, &strides); return strides;