diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index ae509d5bb2b3a..8e03525d1e107 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -37,53 +37,7 @@
 
 namespace rj = arrow::rapidjson;
 
-namespace arrow {
-
-namespace extension {
-
-namespace {
-
-Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
-                      const std::vector<int64_t>& shape,
-                      const std::vector<int64_t>& permutation,
-                      std::vector<int64_t>* strides) {
-  auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
-  if (permutation.empty()) {
-    return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides);
-  }
-  const int byte_width = value_type->byte_width();
-
-  int64_t remaining = 0;
-  if (!shape.empty() && shape.front() > 0) {
-    remaining = byte_width;
-    for (auto i : permutation) {
-      if (i > 0) {
-        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
-          return Status::Invalid(
-              "Strides computed from shape would not fit in 64-bit integer");
-        }
-      }
-    }
-  }
-
-  if (remaining == 0) {
-    strides->assign(shape.size(), byte_width);
-    return Status::OK();
-  }
-
-  strides->push_back(remaining);
-  for (auto i : permutation) {
-    if (i > 0) {
-      remaining /= shape[i];
-      strides->push_back(remaining);
-    }
-  }
-  internal::Permute(permutation, strides);
-
-  return Status::OK();
-}
-
-}  // namespace
+namespace arrow::extension {
 
 bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
   if (extension_name() != other.extension_name()) {
@@ -238,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
   }
 
   std::vector<int64_t> strides;
-  RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
+  RETURN_NOT_OK(
+      internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
   const auto start_position = array->offset() * byte_width;
   const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                                     std::multiplies<>());
@@ -377,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
   internal::Permute(permutation, &shape);
 
   std::vector<int64_t> tensor_strides;
-  const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
   ARROW_RETURN_NOT_OK(
-      ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
+      internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));
 
   const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
   ARROW_ASSIGN_OR_RAISE(
@@ -413,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
 
 const std::vector<int64_t>& FixedShapeTensorType::strides() {
   if (strides_.empty()) {
-    auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
    std::vector<int64_t> tensor_strides;
-    ARROW_CHECK_OK(
-        ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
+    ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
+                                            this->permutation(), &tensor_strides));
     strides_ = tensor_strides;
   }
   return strides_;
@@ -431,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
   return maybe_type.MoveValueUnsafe();
 }
 
-}  // namespace extension
-}  // namespace arrow
+}  // namespace arrow::extension
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h
index 80a602021c60b..f30f8cf0857ee 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.h
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -15,12 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
 #pragma once
 
 #include "arrow/extension_type.h"
 
-namespace arrow {
-namespace extension {
+namespace arrow::extension {
 
 class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
  public:
@@ -126,5 +123,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
     const std::vector<int64_t>& permutation = {},
     const std::vector<std::string>& dim_names = {});
 
-}  // namespace extension
-}  // namespace arrow
+}  // namespace arrow::extension
diff --git a/cpp/src/arrow/extension/tensor_extension_array_test.cc b/cpp/src/arrow/extension/tensor_extension_array_test.cc
index 63a270b5226a9..548c7fcd9d2ba 100644
--- a/cpp/src/arrow/extension/tensor_extension_array_test.cc
+++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc
@@ -506,7 +506,7 @@ TEST_F(TestFixedShapeTensorType, ComputeStrides) {
   ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})");
   auto ext_type_7 = internal::checked_pointer_cast<FixedShapeTensorType>(
       fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}));
-  ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 28}));
+  ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 16}));
   ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
 }
 
@@ -577,7 +577,7 @@ TEST_F(TestFixedShapeTensorType, GetTensor) {
     // Get tensor from extension array with non-trivial permutation
     ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor,
                          Tensor::Make(value_type_, Buffer::Wrap(element_values[i]),
-                                      {4, 3}, {8, 32}, {"y", "x"}));
+                                      {4, 3}, {8, 24}, {"y", "x"}));
     ASSERT_OK_AND_ASSIGN(scalar, permuted_array->GetScalar(i));
     ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor,
                          exact_permuted_ext_type->MakeTensor(
diff --git a/cpp/src/arrow/extension/tensor_internal.cc b/cpp/src/arrow/extension/tensor_internal.cc
index 02a3fdbbae75c..2f3d8ae5d05d1 100644
--- a/cpp/src/arrow/extension/tensor_internal.cc
+++ b/cpp/src/arrow/extension/tensor_internal.cc
@@ -44,14 +44,44 @@ Status IsPermutationValid(const std::vector<int64_t>& permutation) {
   return Status::OK();
 }
 
-Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
-                                            const std::vector<int64_t>& shape,
-                                            const std::vector<int64_t>& permutation) {
-  const auto& fw_type = checked_cast<const FixedWidthType&>(*value_type);
-  std::vector<int64_t> strides;
-  ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, shape, &strides));
-  // If the permutation is empty, the strides are already in the correct order.
-  internal::Permute(permutation, &strides);
-  return strides;
+Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
+                      const std::vector<int64_t>& shape,
+                      const std::vector<int64_t>& permutation,
+                      std::vector<int64_t>* strides) {
+  auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
+  if (permutation.empty()) {
+    return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides);
+  }
+  const int byte_width = value_type->byte_width();
+
+  int64_t remaining = 0;
+  if (!shape.empty() && shape.front() > 0) {
+    remaining = byte_width;
+    for (auto i : permutation) {
+      if (i > 0) {
+        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
+          return Status::Invalid(
+              "Strides computed from shape would not fit in 64-bit integer");
+        }
+      }
+    }
+  }
+
+  if (remaining == 0) {
+    strides->assign(shape.size(), byte_width);
+    return Status::OK();
+  }
+
+  strides->push_back(remaining);
+  for (auto i : permutation) {
+    if (i > 0) {
+      remaining /= shape[i];
+      strides->push_back(remaining);
+    }
+  }
+  internal::Permute(permutation, strides);
+
+  return Status::OK();
 }
+
 }  // namespace arrow::internal
diff --git a/cpp/src/arrow/extension/tensor_internal.h b/cpp/src/arrow/extension/tensor_internal.h
index 40be4b54344f4..1a0bd0b29c2e5 100644
--- a/cpp/src/arrow/extension/tensor_internal.h
+++ b/cpp/src/arrow/extension/tensor_internal.h
@@ -28,8 +28,9 @@ ARROW_EXPORT
 Status IsPermutationValid(const std::vector<int64_t>& permutation);
 
 ARROW_EXPORT
-Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
-                                            const std::vector<int64_t>& shape,
-                                            const std::vector<int64_t>& permutation);
+Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
+                      const std::vector<int64_t>& shape,
+                      const std::vector<int64_t>& permutation,
+                      std::vector<int64_t>* strides);
 
 }  // namespace arrow::internal
diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc b/cpp/src/arrow/extension/variable_shape_tensor.cc
index 739f2c8edf2f1..96e82689183e4 100644
--- a/cpp/src/arrow/extension/variable_shape_tensor.cc
+++ b/cpp/src/arrow/extension/variable_shape_tensor.cc
@@ -36,8 +36,7 @@
 
 namespace rj = arrow::rapidjson;
 
-namespace arrow {
-namespace extension {
+namespace arrow::extension {
 
 bool VariableShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
   if (extension_name() != other.extension_name()) {
@@ -206,35 +205,35 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
     const std::shared_ptr<ExtensionScalar>& scalar) {
-  const auto tensor_scalar = internal::checked_pointer_cast<StructScalar>(scalar->value);
-  const auto ext_type =
-      internal::checked_pointer_cast<VariableShapeTensorType>(scalar->type);
+  const auto& tensor_scalar = internal::checked_cast<const StructScalar&>(*scalar->value);
+  const auto& ext_type =
+      internal::checked_cast<const VariableShapeTensorType&>(*scalar->type);
 
-  ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar->field(0));
-  ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar->field(1));
-  ARROW_CHECK(tensor_scalar->is_valid);
+  ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar.field(0));
+  ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar.field(1));
+  ARROW_CHECK(tensor_scalar.is_valid);
   const auto data_array = internal::checked_pointer_cast<ListScalar>(data_scalar)->value;
   const auto shape_array = internal::checked_pointer_cast<Int32Array>(
       internal::checked_pointer_cast<FixedSizeListScalar>(shape_scalar)->value);
 
-  const auto value_type =
-      internal::checked_pointer_cast<FixedWidthType>(ext_type->value_type());
+  const auto& value_type =
+      internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());
 
   if (data_array->null_count() > 0) {
     return Status::Invalid("Cannot convert data with nulls to Tensor.");
   }
-  auto permutation = ext_type->permutation();
+  auto permutation = ext_type.permutation();
   if (permutation.empty()) {
-    permutation.resize(ext_type->ndim());
+    permutation.resize(ext_type.ndim());
     std::iota(permutation.begin(), permutation.end(), 0);
   }
 
-  ARROW_CHECK_EQ(shape_array->length(), ext_type->ndim());
+  ARROW_CHECK_EQ(shape_array->length(), ext_type.ndim());
   std::vector<int64_t> shape;
-  shape.reserve(ext_type->ndim());
-  for (int64_t j = 0; j < static_cast<int64_t>(ext_type->ndim()); ++j) {
+  shape.reserve(ext_type.ndim());
+  for (int64_t j = 0; j < static_cast<int64_t>(ext_type.ndim()); ++j) {
     const auto size_value = shape_array->Value(j);
     if (size_value < 0) {
       return Status::Invalid("shape must have non-negative values");
@@ -242,16 +241,17 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
     shape.push_back(std::move(size_value));
   }
 
-  std::vector<std::string> dim_names = ext_type->dim_names();
+  std::vector<std::string> dim_names = ext_type.dim_names();
   if (!dim_names.empty()) {
     internal::Permute(permutation, &dim_names);
   }
 
-  ARROW_ASSIGN_OR_RAISE(std::vector<int64_t> strides,
-                        internal::ComputeStrides(value_type, shape, permutation));
+  std::vector<int64_t> strides;
+  ARROW_RETURN_NOT_OK(
+      internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
   internal::Permute(permutation, &shape);
 
-  const auto byte_width = value_type->byte_width();
+  const auto byte_width = value_type.byte_width();
   const auto start_position = data_array->offset() * byte_width;
   const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                                     std::multiplies<>());
@@ -260,8 +260,8 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
       const auto buffer,
       SliceBufferSafe(data_array->data()->buffers[1], start_position, size * byte_width));
 
-  return Tensor::Make(value_type, std::move(buffer), std::move(shape), std::move(strides),
-                      ext_type->dim_names());
+  return Tensor::Make(ext_type.value_type(), std::move(buffer), std::move(shape),
+                      std::move(strides), ext_type.dim_names());
 }
 
 Result<std::shared_ptr<DataType>> VariableShapeTensorType::Make(
@@ -311,5 +311,4 @@ std::shared_ptr<DataType> variable_shape_tensor(
   return maybe_type.MoveValueUnsafe();
 }
 
-}  // namespace extension
-}  // namespace arrow
+}  // namespace arrow::extension
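
The two updated stride expectations in tensor_extension_array_test.cc follow from the permuted-stride arithmetic in the consolidated internal::ComputeStrides. The standalone sketch below is illustration only, not Arrow code: it omits the MultiplyWithOverflow guard and the Status/DataType plumbing, assumes internal::Permute writes values[permutation[i]] to position i, and assumes the GetTensor path hands ComputeStrides the already-permuted cell shape {4, 3}. Under those assumptions it reproduces the new expected values {4, 112, 16} and {8, 24}.

// strides_sketch.cc -- hypothetical helper, for illustration only; not part of this patch.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed Permute semantics: out[i] = values[permutation[i]].
std::vector<int64_t> Permute(const std::vector<int64_t>& permutation,
                             const std::vector<int64_t>& values) {
  std::vector<int64_t> out(values.size());
  for (std::size_t i = 0; i < permutation.size(); ++i) {
    out[i] = values[permutation[i]];
  }
  return out;
}

// Mirrors the non-empty-permutation branch of the consolidated ComputeStrides,
// minus the overflow check: accumulate the byte width over the permuted
// dimensions, then divide it back out to get one stride per dimension.
std::vector<int64_t> ComputeStridesSketch(int byte_width,
                                          const std::vector<int64_t>& shape,
                                          const std::vector<int64_t>& permutation) {
  int64_t remaining = 0;
  if (!shape.empty() && shape.front() > 0) {
    remaining = byte_width;
    for (auto i : permutation) {
      if (i > 0) remaining *= shape[i];
    }
  }
  if (remaining == 0) {
    return std::vector<int64_t>(shape.size(), byte_width);
  }
  std::vector<int64_t> strides;
  strides.push_back(remaining);
  for (auto i : permutation) {
    if (i > 0) {
      remaining /= shape[i];
      strides.push_back(remaining);
    }
  }
  return Permute(permutation, strides);
}

int main() {
  // ComputeStrides test: int32 (4 bytes), shape {3, 4, 7}, permutation {2, 0, 1}
  // -> {4, 112, 16}, the updated expectation for ext_type_7->strides().
  assert((ComputeStridesSketch(4, {3, 4, 7}, {2, 0, 1}) ==
          std::vector<int64_t>{4, 112, 16}));

  // GetTensor test: int64 elements (8 bytes), permuted cell shape {4, 3},
  // permutation {1, 0} -> {8, 24}, the updated expectation for the permuted tensor.
  assert((ComputeStridesSketch(8, {4, 3}, {1, 0}) == std::vector<int64_t>{8, 24}));
  return 0;
}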