Support MakeArrayOfNull for extension types whose storage type is nested.

Review notes: the struct and union visitors now take their field count from the
visited type (not type_), so it agrees with CreateChild(type, ...) when type_ is
an extension type wrapping nested storage.

diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index 6331105af8998..dc1bd62474753 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -364,6 +364,7 @@ TEST_F(TestArray, TestMakeArrayOfNull) {
       dense_union(union_fields1, union_type_codes),
       dense_union(union_fields2, union_type_codes),
       smallint(),  // extension type
+      list_extension_type(),  // nested extension type
       // clang-format on
   };
 
@@ -371,6 +372,7 @@ TEST_F(TestArray, TestMakeArrayOfNull) {
   for (auto type : types) {
     ARROW_SCOPED_TRACE("type = ", type->ToString());
     ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(type, length));
+    ASSERT_EQ(array->type(), type);
     ASSERT_OK(array->ValidateFull());
     ASSERT_EQ(array->length(), length);
     if (is_union(type->id())) {
diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc
index 62c9fa20c34c2..413182de0df77 100644
--- a/cpp/src/arrow/array/util.cc
+++ b/cpp/src/arrow/array/util.cc
@@ -456,19 +456,21 @@ class NullArrayFactory {
   template <typename T>
   enable_if_var_size_list<T, Status> Visit(const T& type) {
     out_->buffers.resize(2, buffer_);
-    ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(0, /*length=*/0));
+    ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(type, 0, /*length=*/0));
     return Status::OK();
   }
 
   Status Visit(const FixedSizeListType& type) {
     ARROW_ASSIGN_OR_RAISE(out_->child_data[0],
-                          CreateChild(0, length_ * type.list_size()));
+                          CreateChild(type, 0, length_ * type.list_size()));
     return Status::OK();
   }
 
   Status Visit(const StructType& type) {
-    for (int i = 0; i < type_->num_fields(); ++i) {
-      ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, length_));
+    // Use the visited type's field count, consistent with CreateChild(type, ...):
+    // type_ may be an extension type whose own field list differs from its storage.
+    for (int i = 0; i < type.num_fields(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, length_));
     }
     return Status::OK();
   }
@@ -498,7 +500,8 @@ class NullArrayFactory {
       child_length = 1;
     }
-    for (int i = 0; i < type_->num_fields(); ++i) {
-      ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, child_length));
+    // As in Visit(StructType): count fields on the visited type, not type_.
+    for (int i = 0; i < type.num_fields(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, child_length));
     }
     return Status::OK();
   }
@@ -511,6 +514,9 @@ class NullArrayFactory {
   }
 
   Status Visit(const ExtensionType& type) {
+    // The storage-type visitors below write into out_->child_data[i], so make sure
+    // there is one slot per storage field before delegating to them.
+    out_->child_data.resize(type.storage_type()->num_fields());
     RETURN_NOT_OK(VisitTypeInline(*type.storage_type(), this));
     return Status::OK();
   }
@@ -519,8 +525,11 @@ class NullArrayFactory {
     return Status::NotImplemented("construction of all-null ", type);
   }
 
-  Result<std::shared_ptr<ArrayData>> CreateChild(int i, int64_t length) {
-    NullArrayFactory child_factory(pool_, type_->field(i)->type(), length);
+  // Build the all-null child for field `i` of `type` (the type being visited, which
+  // may differ from `type_`), sharing this factory's buffer.
+  Result<std::shared_ptr<ArrayData>> CreateChild(const DataType& type, int i,
+                                                 int64_t length) {
+    NullArrayFactory child_factory(pool_, type.field(i)->type(), length);
     child_factory.buffer_ = buffer_;
     return child_factory.Create();
   }
diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h
index 5afe23400767b..338b4cb4da055 100644
--- a/cpp/src/arrow/testing/extension_type.h
+++ b/cpp/src/arrow/testing/extension_type.h
@@ -54,6 +54,12 @@ class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray {
   using ExtensionArray::ExtensionArray;
 };
 
+// Array class produced by ListExtensionType::MakeArray.
+class ARROW_TESTING_EXPORT ListExtensionArray : public ExtensionArray {
+ public:
+  using ExtensionArray::ExtensionArray;
+};
+
 class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
  public:
   SmallintType() : ExtensionType(int16()) {}
@@ -71,6 +77,24 @@ class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
   std::string Serialize() const override { return "smallint"; }
 };
 
+// Test-only extension type named "list-ext" whose storage type is list(int32()).
+class ARROW_TESTING_EXPORT ListExtensionType : public ExtensionType {
+ public:
+  ListExtensionType() : ExtensionType(list(int32())) {}
+
+  std::string extension_name() const override { return "list-ext"; }
+
+  bool ExtensionEquals(const ExtensionType& other) const override;
+
+  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+  Result<std::shared_ptr<DataType>> Deserialize(
+      std::shared_ptr<DataType> storage_type,
+      const std::string& serialized) const override;
+
+  std::string Serialize() const override { return "list-ext"; }
+};
+
 class ARROW_TESTING_EXPORT DictExtensionType : public ExtensionType {
  public:
   DictExtensionType() : ExtensionType(dictionary(int8(), utf8())) {}
@@ -118,6 +142,9 @@ std::shared_ptr<DataType> uuid();
 ARROW_TESTING_EXPORT
 std::shared_ptr<DataType> smallint();
 
+ARROW_TESTING_EXPORT
+std::shared_ptr<DataType> list_extension_type();
+
 ARROW_TESTING_EXPORT
 std::shared_ptr<DataType> dict_extension_type();
 
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 0517be236dce3..56ba94dba8aac 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -797,6 +797,29 @@ Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
   return std::make_shared<SmallintType>();
 }
 
+bool ListExtensionType::ExtensionEquals(const ExtensionType& other) const {
+  return (other.extension_name() == this->extension_name());
+}
+
+std::shared_ptr<Array> ListExtensionType::MakeArray(
+    std::shared_ptr<ArrayData> data) const {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
+  DCHECK_EQ("list-ext", static_cast<const ExtensionType&>(*data->type).extension_name());
+  return std::make_shared<ListExtensionArray>(data);
+}
+
+Result<std::shared_ptr<DataType>> ListExtensionType::Deserialize(
+    std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
+  if (serialized != "list-ext") {
+    return Status::Invalid("Type identifier did not match: '", serialized, "'");
+  }
+  if (!storage_type->Equals(*list(int32()))) {
+    return Status::Invalid("Invalid storage type for ListExtensionType: ",
+                           storage_type->ToString());
+  }
+  return std::make_shared<ListExtensionType>();
+}
+
 bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
   return (other.extension_name() == this->extension_name());
 }
@@ -847,6 +870,10 @@ std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }
 
 std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }
 
+std::shared_ptr<DataType> list_extension_type() {
+  return std::make_shared<ListExtensionType>();
+}
+
 std::shared_ptr<DataType> dict_extension_type() {
   return std::make_shared<DictExtensionType>();
 }