Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Antoine Pitrou <pitrou@free.fr>
  • Loading branch information
rok and pitrou authored Mar 27, 2024
1 parent d71accf commit c9c558a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
4 changes: 2 additions & 2 deletions cpp/src/arrow/extension/tensor_extension_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ TEST_F(TestFixedShapeTensorType, CreateFromArray) {

template <typename T>
void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
auto type = internal::checked_pointer_cast<T>(ext_type);
auto type = internal::checked_pointer_cast<ExtensionType>(ext_type);
auto serialized = type->Serialize();
ASSERT_OK_AND_ASSIGN(auto deserialized,
type->Deserialize(type->storage_type(), serialized));
Expand Down Expand Up @@ -695,7 +695,7 @@ class TestVariableShapeTensorType : public ::testing::Test {
// Checks that an extension type is registered under the name
// "arrow.variable_shape_tensor" and that it reports the EXTENSION type id.
TEST_F(TestVariableShapeTensorType, CheckDummyRegistration) {
// We need a registered dummy type at runtime to allow for IPC deserialization
// Look the type up by its extension name.
auto registered_type = GetExtensionType("arrow.variable_shape_tensor");
// NOTE(review): the two assertions below check the same condition. They look like
// the removed (ASSERT_TRUE) and added (ASSERT_EQ) sides of one diff line rendered
// together — confirm that only the ASSERT_EQ form is present in the real file.
ASSERT_TRUE(registered_type->type_id == Type::EXTENSION);
ASSERT_EQ(registered_type->type_id, Type::EXTENSION);
}

TEST_F(TestVariableShapeTensorType, CreateExtensionType) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/extension/tensor_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cstdint>
#include <vector>

#include "arrow/array/array_nested.h"
#include "arrow/tensor.h"
#include "arrow/util/checked_cast.h"
Expand Down Expand Up @@ -84,6 +85,7 @@ inline Status ComputeStrides(const std::shared_ptr<DataType>& value_type,
strides->push_back(remaining);
}
}
DCHECK_EQ(strides.back(), byte_width);
Permute(permutation, strides);

return Status::OK();
Expand Down
28 changes: 15 additions & 13 deletions cpp/src/arrow/extension/variable_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ bool VariableShapeTensorType::ExtensionEquals(const ExtensionType& other) const
if (extension_name() != other.extension_name()) {
return false;
}
const auto& other_ext = static_cast<const VariableShapeTensorType&>(other);
const auto& other_ext = checked_cast<const VariableShapeTensorType&>(other);
if (this->ndim() != other_ext.ndim()) {
return false;
}
Expand Down Expand Up @@ -150,16 +150,15 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
return Status::Invalid("Expected List storage type, got ",
storage_type->field(1)->type()->ToString());
}
if (std::static_pointer_cast<FixedSizeListType>(storage_type->field(0)->type())
if (checked_cast<const FixedSizeListType&>(*storage_type->field(0)->type())
->value_type() != int32()) {
return Status::Invalid("Expected FixedSizeList value type int32, got ",
storage_type->field(0)->type()->ToString());
}

const auto value_type = storage_type->field(1)->type()->field(0)->type();
const size_t ndim =
std::static_pointer_cast<FixedSizeListType>(storage_type->field(0)->type())
->list_size();
const int32_t ndim =
checked_cast<const FixedSizeListType&>(*storage_type->field(0)->type()).list_size();

rj::Document document;
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError()) {
Expand All @@ -168,6 +167,7 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(

std::vector<int64_t> permutation;
if (document.HasMember("permutation")) {
permutation.reserve(ndim);
for (auto& x : document["permutation"].GetArray()) {
permutation.emplace_back(x.GetInt64());
}
Expand All @@ -177,7 +177,8 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
for (auto& x : document["dim_names"].GetArray()) {
dim_names.reserve(ndim);
for (const auto& x : document["dim_names"].GetArray()) {
dim_names.emplace_back(x.GetString());
}
if (dim_names.size() != ndim) {
Expand All @@ -187,7 +188,8 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(

std::vector<std::optional<int64_t>> uniform_shape;
if (document.HasMember("uniform_shape")) {
for (auto& x : document["uniform_shape"].GetArray()) {
uniform_shape.reserve(ndim);
for (const auto& x : document["uniform_shape"].GetArray()) {
if (x.IsNull()) {
uniform_shape.emplace_back(std::nullopt);
} else {
Expand All @@ -200,8 +202,8 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
}
}

return variable_shape_tensor(value_type, static_cast<int32_t>(ndim), permutation,
dim_names, uniform_shape);
return variable_shape_tensor(value_type, static_cast<int32_t>(ndim), std::move(permutation),
std::move(dim_names), std::move(uniform_shape));
}

std::shared_ptr<Array> VariableShapeTensorType::MakeArray(
Expand Down Expand Up @@ -233,17 +235,17 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
return Status::Invalid("Cannot convert non-fixed-width values to Tensor.");
}
if (data_array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls values to Tensor.");
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}

auto permutation = ext_type->permutation();
if (permutation.empty()) {
for (int64_t j = 0; j < static_cast<int64_t>(ext_type->ndim()); ++j) {
permutation.emplace_back(j);
}
permutation.resize(ndim);
std::iota(permutation.begin(), permutation.end(), 0);
}

std::vector<int64_t> shape;
shape.reserve(ndim);
for (int64_t j = 0; j < static_cast<int64_t>(ext_type->ndim()); ++j) {
ARROW_ASSIGN_OR_RAISE(const auto size, shape_array->GetScalar(j));
auto size_value = internal::checked_pointer_cast<Int32Scalar>(size)->value;
Expand Down
12 changes: 9 additions & 3 deletions cpp/src/arrow/extension/variable_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "arrow/extension_type.h"

namespace arrow {
Expand Down Expand Up @@ -86,9 +92,9 @@ class ARROW_EXPORT VariableShapeTensorType : public ExtensionType {
/// \brief Create a VariableShapeTensorType instance
static Result<std::shared_ptr<DataType>> Make(
const std::shared_ptr<DataType>& value_type, const int32_t ndim,
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {},
const std::vector<std::optional<int64_t>>& uniform_shape = {});
const std::vector<int64_t> permutation = {},
const std::vector<std::string> dim_names = {},
const std::vector<std::optional<int64_t>> uniform_shape = {});

private:
std::shared_ptr<DataType> storage_type_;
Expand Down

0 comments on commit c9c558a

Please sign in to comment.