Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Aug 15, 2023
1 parent 0e4fa8a commit b3e2c84
Show file tree
Hide file tree
Showing 7 changed files with 510 additions and 25 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ if(ARROW_JSON)
list(APPEND
ARROW_SRCS
extension/fixed_shape_tensor.cc
extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

add_arrow_test(test
SOURCES
fixed_shape_tensor_test.cc
tensor_extension_array_test.cc
PREFIX
"arrow-fixed-shape-tensor")
"arrow-canonical-extensions")

arrow_install_all_headers("arrow/extension")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "arrow/extension/fixed_shape_tensor.h"
#include "arrow/extension/variable_shape_tensor.h"

#include "arrow/testing/matchers.h"

Expand All @@ -35,6 +36,10 @@ using FixedShapeTensorType = extension::FixedShapeTensorType;
using extension::fixed_shape_tensor;
using extension::FixedShapeTensorArray;

using VariableShapeTensorType = extension::VariableShapeTensorType;
using extension::variable_shape_tensor;
using extension::VariableShapeTensorArray;

class TestExtensionType : public ::testing::Test {
public:
void SetUp() override {
Expand Down Expand Up @@ -154,43 +159,47 @@ TEST_F(TestExtensionType, CreateFromArray) {
ASSERT_EQ(ext_arr->null_count(), 0);
}

template <typename T>
void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
auto serialized = fst_type->Serialize();
auto type = internal::checked_pointer_cast<T>(ext_type);
auto serialized = type->Serialize();
ASSERT_OK_AND_ASSIGN(auto deserialized,
fst_type->Deserialize(fst_type->storage_type(), serialized));
ASSERT_TRUE(fst_type->Equals(*deserialized));
type->Deserialize(type->storage_type(), serialized));
ASSERT_TRUE(type->Equals(*deserialized));
}

void CheckDeserializationRaises(const std::shared_ptr<DataType>& storage_type,
void CheckDeserializationRaises(const std::shared_ptr<DataType>& extension_type,
const std::shared_ptr<DataType>& storage_type,
const std::string& serialized,
const std::string& expected_message) {
auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int64(), {3, 4}));
auto ext_type = internal::checked_pointer_cast<ExtensionType>(extension_type);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_message),
fst_type->Deserialize(storage_type, serialized));
ext_type->Deserialize(storage_type, serialized));
}

TEST_F(TestExtensionType, MetadataSerializationRoundtrip) {
CheckSerializationRoundtrip(ext_type_);
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {}));
CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"}));
CheckSerializationRoundtrip(
using T = FixedShapeTensorType;
CheckSerializationRoundtrip<T>(ext_type_);
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {0}, {}, {}));
CheckSerializationRoundtrip<T>(fixed_shape_tensor(value_type_, {1}, {0}, {"x"}));
CheckSerializationRoundtrip<T>(
fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"}));
CheckSerializationRoundtrip(
CheckSerializationRoundtrip<T>(
fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"}));

auto storage_type = fixed_size_list(int64(), 12);
CheckDeserializationRaises(boolean(), R"({"shape":[3,4]})",
CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})",
"Expected FixedSizeList storage type, got bool");
CheckDeserializationRaises(storage_type, R"({"dim_names":["x","y"]})",
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})",
"Invalid serialized JSON data");
CheckDeserializationRaises(storage_type, R"({"shape":(3,4)})",
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})",
"Invalid serialized JSON data");
CheckDeserializationRaises(storage_type, R"({"shape":[3,4],"permutation":[1,0,2]})",
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3,4],"permutation":[1,0,2]})",
"Invalid permutation");
CheckDeserializationRaises(storage_type, R"({"shape":[3],"dim_names":["x","y"]})",
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3],"dim_names":["x","y"]})",
"Invalid dim_names");
}

Expand Down Expand Up @@ -434,4 +443,159 @@ TEST_F(TestExtensionType, ComputeStrides) {
ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}

class TestVariableShapeTensorType : public ::testing::Test {
public:
void SetUp() override {
ndim_ = 3;
value_type_ = int64();
permutation_ = {0, 1, 2};
dim_names_ = {"x", "y", "z"};
ext_type_ = internal::checked_pointer_cast<ExtensionType>(
variable_shape_tensor(value_type_, ndim_, permutation_, dim_names_));
shapes_ =
ArrayFromJSON(fixed_size_list(uint32(), ndim_), "[[2,3,1],[1,2,2],[3,1,3]]");
data_ = ArrayFromJSON(list(value_type_),
"[[0,1,2,3,4,5],[6,7,8,9],[10,11,12,13,14,15,16,17,18]]");
serialized_ = R"({"ndim":3,"permutation":[0,1,2],"dim_names":["x","y","z"]})";
storage_arr_ = ArrayFromJSON(
ext_type_->storage_type(),
R"([[[2,3,1],[0,1,2,3,4,5]],[[1,2,2],[6,7,8,9]],[[3,1,3],[10,11,12,13,14,15,16,17,18]]])");
ext_arr_ = internal::checked_pointer_cast<ExtensionArray>(
ExtensionType::WrapArray(ext_type_, storage_arr_));
}

protected:
uint32_t ndim_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> permutation_;
std::vector<std::string> dim_names_;
std::shared_ptr<ExtensionType> ext_type_;
std::shared_ptr<Array> shapes_;
std::shared_ptr<Array> data_;
std::string serialized_;
std::shared_ptr<Array> storage_arr_;
std::shared_ptr<ExtensionArray> ext_arr_;
};

TEST_F(TestVariableShapeTensorType, CheckDummyRegistration) {
// We need a registered dummy type at runtime to allow for IPC deserialization
auto registered_type = GetExtensionType("arrow.variable_shape_tensor");
ASSERT_TRUE(registered_type->type_id == Type::EXTENSION);
}

TEST_F(TestVariableShapeTensorType, CreateExtensionType) {
auto exact_ext_type =
internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);

// Test ExtensionType methods
ASSERT_EQ(ext_type_->extension_name(), "arrow.variable_shape_tensor");
ASSERT_TRUE(ext_type_->Equals(*exact_ext_type));
auto expected_type = struct_({
::arrow::field("shape", fixed_size_list(uint32(), ndim_)),
::arrow::field("data", list(value_type_)),
});

ASSERT_TRUE(ext_type_->storage_type()->Equals(*expected_type));
ASSERT_EQ(ext_type_->Serialize(), serialized_);
ASSERT_OK_AND_ASSIGN(auto ds,
ext_type_->Deserialize(ext_type_->storage_type(), serialized_));
auto deserialized = internal::checked_pointer_cast<ExtensionType>(ds);
ASSERT_TRUE(deserialized->Equals(*exact_ext_type));
ASSERT_TRUE(deserialized->Equals(*ext_type_));

// Test FixedShapeTensorType methods
ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION);
ASSERT_EQ(exact_ext_type->ndim(), ndim_);
ASSERT_EQ(exact_ext_type->value_type(), value_type_);
ASSERT_EQ(exact_ext_type->permutation(), permutation_);
ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
testing::HasSubstr("Invalid: permutation size must match ndim. Expected: 3 Got: 1"),
VariableShapeTensorType::Make(value_type_, ndim_, {0}));
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, testing::HasSubstr("Invalid: dim_names size must match ndim."),
VariableShapeTensorType::Make(value_type_, ndim_, {}, {"x"}));
}

TEST_F(TestVariableShapeTensorType, EqualsCases) {
auto ext_type_permutation_1 = variable_shape_tensor(int64(), 2, {0, 1}, {"x", "y"});
auto ext_type_permutation_2 = variable_shape_tensor(int64(), 2, {1, 0}, {"x", "y"});
auto ext_type_no_permutation = variable_shape_tensor(int64(), 2, {}, {"x", "y"});

ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_permutation_1));

ASSERT_FALSE(
variable_shape_tensor(int32(), 2, {}, {"x", "y"})->Equals(ext_type_no_permutation));
ASSERT_FALSE(variable_shape_tensor(int64(), 2, {}, {})
->Equals(variable_shape_tensor(int64(), 3, {}, {})));
ASSERT_FALSE(
variable_shape_tensor(int64(), 2, {}, {"H", "W"})->Equals(ext_type_no_permutation));

ASSERT_TRUE(ext_type_no_permutation->Equals(ext_type_permutation_1));
ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_no_permutation));
ASSERT_FALSE(ext_type_no_permutation->Equals(ext_type_permutation_2));
ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_no_permutation));
ASSERT_FALSE(ext_type_permutation_1->Equals(ext_type_permutation_2));
ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1));
}

TEST_F(TestVariableShapeTensorType, CreateFromArray) {
std::vector<std::string> field_names = {"shapes", "data"};
ASSERT_OK_AND_ASSIGN(auto storage_arr,
StructArray::Make({shapes_, data_}, field_names));
auto arr = ExtensionType::WrapArray(ext_type_, storage_arr);
ASSERT_TRUE(ext_arr_->Equals(*arr));
}

TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) {
using T = VariableShapeTensorType;

CheckSerializationRoundtrip<T>(ext_type_);
CheckSerializationRoundtrip<T>(variable_shape_tensor(value_type_, {}, {}, {}));
CheckSerializationRoundtrip<T>(variable_shape_tensor(value_type_, {0}, {}, {}));
CheckSerializationRoundtrip<T>(variable_shape_tensor(value_type_, {1}, {0}, {"x"}));
CheckSerializationRoundtrip<T>(
variable_shape_tensor(value_type_, 3, {0, 1, 2}, {"H", "W", "C"}));
CheckSerializationRoundtrip<T>(
variable_shape_tensor(value_type_, 3, {2, 0, 1}, {"C", "H", "W"}));

auto storage_type = ext_type_->storage_type();
CheckDeserializationRaises(ext_type_, boolean(), R"({"shape":[3,4]})",
"Expected Struct storage type, got bool");
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})",
"Missing ndim");
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":(3,4)})",
"Invalid serialized JSON data");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"ndim":2,"permutation":[1,0,2]})",
"Invalid permutation");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"ndim":1,"dim_names":["x","y"]})", "Invalid dim_names");
}

TEST_F(TestVariableShapeTensorType, RoudtripBatch) {
auto exact_ext_type =
internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);

// Pass extension array, expect getting back extension array
std::shared_ptr<RecordBatch> read_batch;
auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_);
auto batch = RecordBatch::Make(schema({ext_field}), ext_arr_->length(), {ext_arr_});
RoundtripBatch(batch, &read_batch);
CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);

// Pass extension metadata and storage array, expect getting back extension array
std::shared_ptr<RecordBatch> read_batch2;
auto ext_metadata =
key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()},
{"ARROW:extension:metadata", serialized_}});
ext_field = field(/*name=*/"f0", /*type=*/ext_type_->storage_type(), /*nullable=*/true,
/*metadata=*/ext_metadata);
auto batch2 = RecordBatch::Make(schema({ext_field}), ext_arr_->length(), {ext_arr_});
RoundtripBatch(batch2, &read_batch2);
CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
}

} // namespace arrow
Loading

0 comments on commit b3e2c84

Please sign in to comment.