Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Mar 27, 2024
1 parent d6fe235 commit de7a9be
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 156 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ endif()
if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/fixed_shape_tensor.cc
extension/tensor_internal.cc
extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
Expand Down
17 changes: 8 additions & 9 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
internal::Permute<std::string>(permutation, &dim_names);
}

std::vector<int64_t> strides;
RETURN_NOT_OK(internal::ComputeStrides(value_type, shape, permutation, &strides));
ARROW_ASSIGN_OR_RAISE(std::vector<int64_t> strides,
internal::ComputeStrides(value_type, shape, permutation));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
Expand Down Expand Up @@ -330,9 +330,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
ARROW_RETURN_NOT_OK(
internal::ComputeStrides(value_type, shape, permutation, &tensor_strides));
ARROW_ASSIGN_OR_RAISE(std::vector<int64_t> tensor_strides,
internal::ComputeStrides(value_type, shape, permutation));

const auto raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -366,10 +365,10 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(internal::ComputeStrides(this->value_type_, this->shape(),
this->permutation(), &tensor_strides));
strides_ = tensor_strides;
Result<std::vector<int64_t>> maybe_tensor_strides =
internal::ComputeStrides(this->value_type_, this->shape(), this->permutation());
ARROW_DCHECK_OK(maybe_tensor_strides.status());
strides_ = maybe_tensor_strides.MoveValueUnsafe();
}
return strides_;
}
Expand Down
47 changes: 21 additions & 26 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ TEST_F(TestFixedShapeTensorType, CreateFromArray) {
ASSERT_EQ(ext_arr->null_count(), 0);
}

template <typename T>
void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
auto type = internal::checked_pointer_cast<ExtensionType>(ext_type);
auto serialized = type->Serialize();
Expand All @@ -189,14 +188,13 @@ void CheckDeserializationRaises(const std::shared_ptr<DataType>& extension_type,
}

TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) {
using T = FixedShapeTensorType;
CheckSerializationRoundtrip<T>(ext_type_);
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {0}, {}, {}));
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {1}, {0}, {"x"}));
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(ext_type_);
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {}));
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"}));
CheckSerializationRoundtrip(
fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"}));
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(
fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"}));

auto storage_type = fixed_size_list(int64(), 12);
Expand Down Expand Up @@ -529,7 +527,7 @@ TEST_F(TestFixedShapeTensorType, ComputeStrides) {
ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}

TEST_F(TestFixedShapeTensorType, FixedShapeTensoToString) {
TEST_F(TestFixedShapeTensorType, FixedShapeTensorToString) {
auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);

auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
Expand Down Expand Up @@ -670,7 +668,7 @@ class TestVariableShapeTensorType : public ::testing::Test {
R"({"permutation":[0,1,2],"dim_names":["x","y","z"],"uniform_shape":[null,1,null]})";
storage_arr_ = ArrayFromJSON(
ext_type_->storage_type(),
R"([[[2,3,1],[0,1,2,3,4,5]],[[1,2,2],[6,7,8,9]],[[3,1,3],[10,11,12,13,14,15,16,17,18]]])");
R"([[[0,1,2,3,4,5],[2,3,1]],[[6,7,8,9],[1,2,2]],[[10,11,12,13,14,15,16,17,18],[3,1,3]]])");
ext_arr_ = internal::checked_pointer_cast<ExtensionArray>(
ExtensionType::WrapArray(ext_type_, storage_arr_));
}
Expand Down Expand Up @@ -705,10 +703,9 @@ TEST_F(TestVariableShapeTensorType, CreateExtensionType) {
// Test ExtensionType methods
ASSERT_EQ(ext_type_->extension_name(), "arrow.variable_shape_tensor");
ASSERT_TRUE(ext_type_->Equals(*exact_ext_type));
auto expected_type = struct_({
::arrow::field("shape", fixed_size_list(int32(), ndim_)),
::arrow::field("data", list(value_type_)),
});
auto expected_type =
struct_({::arrow::field("data", list(value_type_)),
::arrow::field("shape", fixed_size_list(int32(), ndim_))});

ASSERT_TRUE(ext_type_->storage_type()->Equals(*expected_type));
ASSERT_EQ(ext_type_->Serialize(), serialized_);
Expand All @@ -735,7 +732,7 @@ TEST_F(TestVariableShapeTensorType, CreateExtensionType) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
testing::HasSubstr("Invalid: Permutation indices for 3 dimensional tensors must be "
"unique and within [0, 2] range. Got: [0,0,2]"),
"unique and within [0, 2] range. Got: [2,0,0]"),
VariableShapeTensorType::Make(value_type_, 3, {2, 0, 0}, {"C", "H", "W"}));
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
Expand Down Expand Up @@ -767,18 +764,16 @@ TEST_F(TestVariableShapeTensorType, EqualsCases) {
}

TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) {
using T = VariableShapeTensorType;

CheckSerializationRoundtrip<T>(ext_type_);
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(ext_type_);
CheckSerializationRoundtrip(
variable_shape_tensor(value_type_, 3, {1, 2, 0}, {"x", "y", "z"}));
CheckSerializationRoundtrip<T>(variable_shape_tensor(value_type_, 0, {}, {}));
CheckSerializationRoundtrip<T>(variable_shape_tensor(value_type_, 1, {0}, {"x"}));
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 0, {}, {}));
CheckSerializationRoundtrip(variable_shape_tensor(value_type_, 1, {0}, {"x"}));
CheckSerializationRoundtrip(
variable_shape_tensor(value_type_, 3, {0, 1, 2}, {"H", "W", "C"}));
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(
variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}));
CheckSerializationRoundtrip<T>(
CheckSerializationRoundtrip(
variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}, {0, 1, 2}));

auto storage_type = ext_type_->storage_type();
Expand All @@ -787,9 +782,9 @@ TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) {
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})",
"Invalid serialized JSON data");
CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":[1,0]})",
"Invalid permutation");
"Invalid: permutation");
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})",
"Invalid dim_names");
"Invalid: dim_names");
}

TEST_F(TestVariableShapeTensorType, RoudtripBatch) {
Expand Down
90 changes: 90 additions & 0 deletions cpp/src/arrow/extension/tensor_internal.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/array/array_nested.h"
#include "arrow/tensor.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/sort.h"

#include "arrow/status.h"
#include "arrow/util/logging.h"
#include "arrow/util/print.h"

namespace arrow::internal {

/// \brief Verify that `permutation` contains every index in
/// [0, permutation.size()) exactly once.
///
/// \param[in] permutation candidate dimension permutation
/// \return Status::OK() when every entry is in range and unique, otherwise
/// Status::Invalid with a message that echoes the offending permutation.
Status IsPermutationValid(const std::vector<int64_t>& permutation) {
  const int64_t ndim = static_cast<int64_t>(permutation.size());
  // visited[d] is set to 1 the first time index d is encountered.
  std::vector<uint8_t> visited(ndim, 0);

  for (const int64_t index : permutation) {
    const bool in_range = index >= 0 && index < ndim;
    if (in_range && visited[index] == 0) {
      visited[index] = 1;
      continue;
    }
    // Either out of range or a repeated index: report the whole permutation.
    return Status::Invalid(
        "Permutation indices for ", ndim,
        " dimensional tensors must be unique and within [0, ", ndim - 1,
        "] range. Got: ", ::arrow::internal::PrintVector{permutation, ","});
  }
  return Status::OK();
}

/// \brief Compute byte strides for a tensor with the given shape and
/// dimension permutation.
///
/// \param[in] value_type element type; it is cast to FixedWidthType and its
/// byte_width() is used as the innermost stride
/// \param[in] shape tensor shape
/// \param[in] permutation dimension permutation; when empty the computation is
/// delegated to ComputeRowMajorStrides
/// \return the computed strides, the error reported by ComputeRowMajorStrides,
/// or Status::Invalid if a stride would overflow a 64-bit integer
Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
                                            const std::vector<int64_t>& shape,
                                            const std::vector<int64_t>& permutation) {
  const auto fixed_width_type =
      internal::checked_pointer_cast<FixedWidthType>(value_type);

  std::vector<int64_t> strides;
  if (permutation.empty()) {
    // Propagate the failure to the caller. A debug-only assertion here would
    // let non-debug builds hand back `strides` from a failed computation as if
    // it had succeeded.
    ARROW_RETURN_NOT_OK(
        internal::ComputeRowMajorStrides(*fixed_width_type, shape, &strides));
    return strides;
  }
  const int byte_width = value_type->byte_width();

  // `remaining` becomes byte_width multiplied by shape[i] for every
  // permutation entry i > 0; it stays 0 when the shape is empty or its first
  // dimension is not positive.
  int64_t remaining = 0;
  if (!shape.empty() && shape.front() > 0) {
    remaining = byte_width;
    for (auto i : permutation) {
      if (i > 0) {
        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
          return Status::Invalid(
              "Strides computed from shape would not fit in 64-bit integer");
        }
      }
    }
  }

  // Degenerate case (see above, or some shape[i] == 0): every stride is the
  // element width. This also guarantees the divisions below never see a zero
  // divisor.
  if (remaining == 0) {
    strides.assign(shape.size(), byte_width);
    return strides;
  }

  // Peel one dimension at a time off the running product, in permutation
  // order, then reorder the strides with the same permutation.
  strides.push_back(remaining);
  for (auto i : permutation) {
    if (i > 0) {
      remaining /= shape[i];
      strides.push_back(remaining);
    }
  }
  DCHECK_EQ(strides.back(), byte_width);
  Permute(permutation, &strides);

  return strides;
}
} // namespace arrow::internal
68 changes: 4 additions & 64 deletions cpp/src/arrow/extension/tensor_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,75 +21,15 @@
#include <vector>

#include "arrow/array/array_nested.h"
#include "arrow/tensor.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/sort.h"

#include "arrow/status.h"
#include "arrow/util/logging.h"
#include "arrow/util/print.h"

namespace arrow::internal {

ARROW_EXPORT
inline Status IsPermutationValid(const std::vector<int64_t>& permutation) {
const auto size = static_cast<int64_t>(permutation.size());
std::vector<uint8_t> dim_seen(size, 0);

for (const auto p : permutation) {
if (p < 0 || p >= size || dim_seen[p] != 0) {
return Status::Invalid(
"Permutation indices for ", size,
" dimensional tensors must be unique and within [0, ", size - 1,
"] range. Got: ", ::arrow::internal::PrintVector{permutation, ","});
}
dim_seen[p] = 1;
}
return Status::OK();
}
Status IsPermutationValid(const std::vector<int64_t>& permutation);

ARROW_EXPORT
inline Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
const auto fixed_width_type =
internal::checked_pointer_cast<FixedWidthType>(value_type);
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(*fixed_width_type.get(), shape, strides);
}
const int byte_width = value_type->byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
DCHECK_EQ(strides->back(), byte_width);
Permute(permutation, strides);

return Status::OK();
}
Result<std::vector<int64_t>> ComputeStrides(const std::shared_ptr<DataType>& value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation);

} // namespace arrow::internal
Loading

0 comments on commit de7a9be

Please sign in to comment.