Skip to content

Commit

Permalink
ARROW-15241: [C++] MakeArrayOfNull fails on extension types with a ne…
Browse files Browse the repository at this point in the history
…sted storage type

Closes #12066 from westonpace/bugfix/ARROW-15241--make-array-of-null-nested-storage

Authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
westonpace committed Jan 5, 2022
1 parent 49093a1 commit b596f29
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 6 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/array/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,15 @@ TEST_F(TestArray, TestMakeArrayOfNull) {
dense_union(union_fields1, union_type_codes),
dense_union(union_fields2, union_type_codes),
smallint(), // extension type
list_extension_type(), // nested extension type
// clang-format on
};

for (int64_t length : {0, 1, 16, 133}) {
for (auto type : types) {
ARROW_SCOPED_TRACE("type = ", type->ToString());
ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(type, length));
ASSERT_EQ(array->type(), type);
ASSERT_OK(array->ValidateFull());
ASSERT_EQ(array->length(), length);
if (is_union(type->id())) {
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/arrow/array/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,19 +456,19 @@ class NullArrayFactory {
template <typename T>
enable_if_var_size_list<T, Status> Visit(const T& type) {
out_->buffers.resize(2, buffer_);
ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(0, /*length=*/0));
ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(type, 0, /*length=*/0));
return Status::OK();
}

Status Visit(const FixedSizeListType& type) {
ARROW_ASSIGN_OR_RAISE(out_->child_data[0],
CreateChild(0, length_ * type.list_size()));
CreateChild(type, 0, length_ * type.list_size()));
return Status::OK();
}

Status Visit(const StructType& type) {
for (int i = 0; i < type_->num_fields(); ++i) {
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, length_));
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, length_));
}
return Status::OK();
}
Expand Down Expand Up @@ -498,7 +498,7 @@ class NullArrayFactory {
child_length = 1;
}
for (int i = 0; i < type_->num_fields(); ++i) {
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, child_length));
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, child_length));
}
return Status::OK();
}
Expand All @@ -511,6 +511,7 @@ class NullArrayFactory {
}

Status Visit(const ExtensionType& type) {
out_->child_data.resize(type.storage_type()->num_fields());
RETURN_NOT_OK(VisitTypeInline(*type.storage_type(), this));
return Status::OK();
}
Expand All @@ -519,8 +520,9 @@ class NullArrayFactory {
return Status::NotImplemented("construction of all-null ", type);
}

Result<std::shared_ptr<ArrayData>> CreateChild(int i, int64_t length) {
NullArrayFactory child_factory(pool_, type_->field(i)->type(), length);
Result<std::shared_ptr<ArrayData>> CreateChild(const DataType& type, int i,
int64_t length) {
NullArrayFactory child_factory(pool_, type.field(i)->type(), length);
child_factory.buffer_ = buffer_;
return child_factory.Create();
}
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/arrow/testing/extension_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray {
using ExtensionArray::ExtensionArray;
};

class ARROW_TESTING_EXPORT ListExtensionArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
};

class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
public:
SmallintType() : ExtensionType(int16()) {}
Expand All @@ -71,6 +76,23 @@ class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
std::string Serialize() const override { return "smallint"; }
};

class ARROW_TESTING_EXPORT ListExtensionType : public ExtensionType {
public:
ListExtensionType() : ExtensionType(list(int32())) {}

std::string extension_name() const override { return "list-ext"; }

bool ExtensionEquals(const ExtensionType& other) const override;

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

Result<std::shared_ptr<DataType>> Deserialize(
std::shared_ptr<DataType> storage_type,
const std::string& serialized) const override;

std::string Serialize() const override { return "list-ext"; }
};

class ARROW_TESTING_EXPORT DictExtensionType : public ExtensionType {
public:
DictExtensionType() : ExtensionType(dictionary(int8(), utf8())) {}
Expand Down Expand Up @@ -118,6 +140,9 @@ std::shared_ptr<DataType> uuid();
ARROW_TESTING_EXPORT
std::shared_ptr<DataType> smallint();

ARROW_TESTING_EXPORT
std::shared_ptr<DataType> list_extension_type();

ARROW_TESTING_EXPORT
std::shared_ptr<DataType> dict_extension_type();

Expand Down
27 changes: 27 additions & 0 deletions cpp/src/arrow/testing/gtest_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,29 @@ Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
return std::make_shared<SmallintType>();
}

bool ListExtensionType::ExtensionEquals(const ExtensionType& other) const {
return (other.extension_name() == this->extension_name());
}

std::shared_ptr<Array> ListExtensionType::MakeArray(
std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("list-ext", static_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<ListExtensionArray>(data);
}

Result<std::shared_ptr<DataType>> ListExtensionType::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
if (serialized != "list-ext") {
return Status::Invalid("Type identifier did not match: '", serialized, "'");
}
if (!storage_type->Equals(*list(int32()))) {
return Status::Invalid("Invalid storage type for ListExtensionType: ",
storage_type->ToString());
}
return std::make_shared<ListExtensionType>();
}

bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
return (other.extension_name() == this->extension_name());
}
Expand Down Expand Up @@ -847,6 +870,10 @@ std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }

std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }

std::shared_ptr<DataType> list_extension_type() {
return std::make_shared<ListExtensionType>();
}

std::shared_ptr<DataType> dict_extension_type() {
return std::make_shared<DictExtensionType>();
}
Expand Down

0 comments on commit b596f29

Please sign in to comment.