Skip to content

Commit

Permalink
refactor ComputeStrides
Browse files (browse the repository at this point in the history)
  • Loading branch information
rok committed Mar 29, 2024
1 parent 48f9191 commit 8e6e990
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 36 deletions.
8 changes: 4 additions & 4 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -514,16 +514,16 @@ TEST_F(TestFixedShapeTensorType, ComputeStrides) {

auto ext_type_5 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2}));
ASSERT_EQ(ext_type_5->strides(), (std::vector<int64_t>{56, 224, 8}));
ASSERT_EQ(ext_type_5->strides(), (std::vector<int64_t>{56, 168, 8}));
ASSERT_EQ(ext_type_5->Serialize(), R"({"shape":[3,4,7],"permutation":[1,0,2]})");

auto ext_type_6 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {}));
ASSERT_EQ(ext_type_6->strides(), (std::vector<int64_t>{56, 8, 224}));
ASSERT_EQ(ext_type_6->strides(), (std::vector<int64_t>{32, 8, 96}));
ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})");
auto ext_type_7 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}));
ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 16}));
ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 84, 12}));
ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}

Expand Down Expand Up @@ -594,7 +594,7 @@ TEST_F(TestFixedShapeTensorType, GetTensor) {
// Get tensor from extension array with non-trivial permutation
ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor,
Tensor::Make(value_type_, Buffer::Wrap(element_values[i]),
{4, 3}, {8, 24}, {"y", "x"}));
{4, 3}, {8, 32}, {"y", "x"}));
ASSERT_OK_AND_ASSIGN(scalar, permuted_array->GetScalar(i));
ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor,
exact_permuted_ext_type->MakeTensor(
Expand Down
39 changes: 7 additions & 32 deletions cpp/src/arrow/extension/tensor_internal.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,43 +47,18 @@ Status IsPermutationValid(const std::vector<int64_t>& permutation) {
Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation) {
const auto fixed_width_type =
internal::checked_pointer_cast<FixedWidthType>(value_type);

const auto& fw_type = checked_cast<const FixedWidthType&>(*value_type);
std::vector<int64_t> strides;
if (permutation.empty()) {
ARROW_DCHECK_OK(
internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, &strides));
return strides;
}
const int byte_width = value_type->byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides.assign(shape.size(), byte_width);
if (permutation.empty()) {
ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, shape, &strides));
return strides;
}

strides.push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides.push_back(remaining);
}
}
DCHECK_EQ(strides.back(), byte_width);
auto permuted_shape = std::move(shape);
auto reverse_permutation = internal::ArgSort(permutation, std::less<>());
Permute(reverse_permutation, &permuted_shape);
ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, permuted_shape, &strides));
Permute(permutation, &strides);

return strides;
Expand Down

0 comments on commit 8e6e990

Please sign in to comment.