Skip to content

Commit

Permalink
Post rebase changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Jun 7, 2024
1 parent 6ccbecc commit e9edba4
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 98 deletions.
62 changes: 7 additions & 55 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,53 +37,7 @@

namespace rj = arrow::rapidjson;

namespace arrow {

namespace extension {

namespace {

Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides);
}
const int byte_width = value_type->byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
internal::Permute(permutation, strides);

return Status::OK();
}

} // namespace
namespace arrow::extension {

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
Expand Down Expand Up @@ -238,7 +192,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
RETURN_NOT_OK(
internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
Expand Down Expand Up @@ -377,9 +332,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
ARROW_RETURN_NOT_OK(
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));

const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -413,10 +367,9 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
this->permutation(), &tensor_strides));
strides_ = tensor_strides;
}
return strides_;
Expand All @@ -431,5 +384,4 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
return maybe_type.MoveValueUnsafe();
}

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
8 changes: 2 additions & 6 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
Expand Down Expand Up @@ -126,5 +123,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
4 changes: 2 additions & 2 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ TEST_F(TestFixedShapeTensorType, ComputeStrides) {
ASSERT_EQ(ext_type_6->Serialize(), R"({"shape":[3,4,7],"permutation":[1,2,0]})");
auto ext_type_7 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}));
ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 28}));
ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 16}));
ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}

Expand Down Expand Up @@ -577,7 +577,7 @@ TEST_F(TestFixedShapeTensorType, GetTensor) {
// Get tensor from extension array with non-trivial permutation
ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor,
Tensor::Make(value_type_, Buffer::Wrap(element_values[i]),
{4, 3}, {8, 32}, {"y", "x"}));
{4, 3}, {8, 24}, {"y", "x"}));
ASSERT_OK_AND_ASSIGN(scalar, permuted_array->GetScalar(i));
ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor,
exact_permuted_ext_type->MakeTensor(
Expand Down
48 changes: 39 additions & 9 deletions cpp/src/arrow/extension/tensor_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,44 @@ Status IsPermutationValid(const std::vector<int64_t>& permutation) {
return Status::OK();
}

Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation) {
const auto& fw_type = checked_cast<const FixedWidthType&>(*value_type);
std::vector<int64_t> strides;
ARROW_DCHECK_OK(internal::ComputeRowMajorStrides(fw_type, shape, &strides));
// If the permutation is empty, the strides are already in the correct order.
internal::Permute<int64_t>(permutation, &strides);
return strides;
Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
auto fixed_width_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides);
}
const int byte_width = value_type->byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
internal::Permute(permutation, strides);

return Status::OK();
}

} // namespace arrow::internal
7 changes: 4 additions & 3 deletions cpp/src/arrow/extension/tensor_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ ARROW_EXPORT
Status IsPermutationValid(const std::vector<int64_t>& permutation);

ARROW_EXPORT
Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation);
Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides);

} // namespace arrow::internal
45 changes: 22 additions & 23 deletions cpp/src/arrow/extension/variable_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@

namespace rj = arrow::rapidjson;

namespace arrow {
namespace extension {
namespace arrow::extension {

bool VariableShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
Expand Down Expand Up @@ -206,52 +205,53 @@ std::shared_ptr<Array> VariableShapeTensorType::MakeArray(

Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
const std::shared_ptr<ExtensionScalar>& scalar) {
const auto tensor_scalar = internal::checked_pointer_cast<StructScalar>(scalar->value);
const auto ext_type =
internal::checked_pointer_cast<VariableShapeTensorType>(scalar->type);
const auto& tensor_scalar = internal::checked_cast<const StructScalar&>(*scalar->value);
const auto& ext_type =
internal::checked_cast<const VariableShapeTensorType&>(*scalar->type);

ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar->field(0));
ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar->field(1));
ARROW_CHECK(tensor_scalar->is_valid);
ARROW_ASSIGN_OR_RAISE(const auto data_scalar, tensor_scalar.field(0));
ARROW_ASSIGN_OR_RAISE(const auto shape_scalar, tensor_scalar.field(1));
ARROW_CHECK(tensor_scalar.is_valid);
const auto data_array =
internal::checked_pointer_cast<BaseListScalar>(data_scalar)->value;
const auto shape_array = internal::checked_pointer_cast<Int32Array>(
internal::checked_pointer_cast<FixedSizeListScalar>(shape_scalar)->value);

const auto value_type =
internal::checked_pointer_cast<FixedWidthType>(ext_type->value_type());
const auto& value_type =
internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());

if (data_array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}

auto permutation = ext_type->permutation();
auto permutation = ext_type.permutation();
if (permutation.empty()) {
permutation.resize(ext_type->ndim());
permutation.resize(ext_type.ndim());
std::iota(permutation.begin(), permutation.end(), 0);
}

ARROW_CHECK_EQ(shape_array->length(), ext_type->ndim());
ARROW_CHECK_EQ(shape_array->length(), ext_type.ndim());
std::vector<int64_t> shape;
shape.reserve(ext_type->ndim());
for (int64_t j = 0; j < static_cast<int64_t>(ext_type->ndim()); ++j) {
shape.reserve(ext_type.ndim());
for (int64_t j = 0; j < static_cast<int64_t>(ext_type.ndim()); ++j) {
const auto size_value = shape_array->Value(j);
if (size_value < 0) {
return Status::Invalid("shape must have non-negative values");
}
shape.push_back(std::move(size_value));
}

std::vector<std::string> dim_names = ext_type->dim_names();
std::vector<std::string> dim_names = ext_type.dim_names();
if (!dim_names.empty()) {
internal::Permute<std::string>(permutation, &dim_names);
}

ARROW_ASSIGN_OR_RAISE(std::vector<int64_t> strides,
internal::ComputeStrides(value_type, shape, permutation));
std::vector<int64_t> strides;
ARROW_RETURN_NOT_OK(
internal::ComputeStrides(ext_type.value_type(), shape, permutation, &strides));
internal::Permute<int64_t>(permutation, &shape);

const auto byte_width = value_type->byte_width();
const auto byte_width = value_type.byte_width();
const auto start_position = data_array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
Expand All @@ -260,8 +260,8 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
const auto buffer,
SliceBufferSafe(data_array->data()->buffers[1], start_position, size * byte_width));

return Tensor::Make(value_type, std::move(buffer), std::move(shape), std::move(strides),
ext_type->dim_names());
return Tensor::Make(ext_type.value_type(), std::move(buffer), std::move(shape),
std::move(strides), ext_type.dim_names());
}

Result<std::shared_ptr<DataType>> VariableShapeTensorType::Make(
Expand Down Expand Up @@ -311,5 +311,4 @@ std::shared_ptr<DataType> variable_shape_tensor(
return maybe_type.MoveValueUnsafe();
}

} // namespace extension
} // namespace arrow
} // namespace arrow::extension

0 comments on commit e9edba4

Please sign in to comment.