Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [C++] DenseUnionArray #206

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/src/arrow/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ class ARROW_EXPORT ArrayBuilder {
DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
};

// Minimal concrete ArrayBuilder: it forwards the pool and type to the base
// class and adds no state or append methods of its own.
//
// NOTE(review): Finish() reports success without writing anything to `out`,
// so the caller's pointer keeps whatever value it had before the call.
class ARROW_EXPORT NullArrayBuilder : public ArrayBuilder {
 public:
  explicit NullArrayBuilder(MemoryPool* pool, const TypePtr& type)
      : ArrayBuilder(pool, type) {}

  virtual ~NullArrayBuilder() {}

  // Always succeeds; `out` is left untouched.
  Status Finish(std::shared_ptr<Array>* out) override {
    Status result = Status::OK();
    return result;
  }
};

} // namespace arrow

#endif // ARROW_BUILDER_H_
24 changes: 24 additions & 0 deletions cpp/src/arrow/ipc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/types/primitive.h"
#include "arrow/types/string.h"
#include "arrow/types/struct.h"
#include "arrow/types/union.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/buffer.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -115,6 +116,13 @@ Status VisitArray(const Array* arr, std::vector<flatbuf::FieldNode>* field_nodes
RETURN_NOT_OK(
VisitArray(field.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else if (arr->type_enum() == Type::UNION) {
const auto union_arr = static_cast<const UnionArray*>(arr);
buffers->push_back(union_arr->types());
buffers->push_back(union_arr->offset_buf());
for (auto& child_arr : union_arr->children()) {
RETURN_NOT_OK(VisitArray(child_arr.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else {
return Status::NotImplemented("Unrecognized type");
}
Expand Down Expand Up @@ -363,6 +371,22 @@ class RecordBatchReader::RecordBatchReaderImpl {
out->reset(new StructArray(
type, field_meta.length, fields, field_meta.null_count, null_bitmap));
return Status::OK();
} else if (type->type == Type::UNION) {
std::shared_ptr<Buffer> types;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &types));
std::shared_ptr<Buffer> offset_buf;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &offset_buf));
auto union_type = std::dynamic_pointer_cast<UnionType>(type);
const int num_children = union_type->num_children();
std::vector<ArrayPtr> results;
for (int child_idx = 0; child_idx < num_children; ++child_idx) {
std::shared_ptr<Array> result;
RETURN_NOT_OK(NextArray(union_type->child(child_idx).get(), max_recursion_depth - 1, &result));
results.push_back(result);
}
out->reset(new UnionArray(
type, field_meta.length, results, types, offset_buf, field_meta.null_count, null_bitmap));
return Status::OK();
}

return Status::NotImplemented("Non-primitive types not complete yet");
Expand Down
17 changes: 17 additions & 0 deletions cpp/src/arrow/ipc/ipc-metadata-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/schema.h"
#include "arrow/test-util.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/status.h"

namespace arrow {
Expand Down Expand Up @@ -97,6 +98,22 @@ TEST_F(TestSchemaMessage, NestedFields) {
CheckRoundtrip(&schema);
}

// A schema holding a single union field (children: int32 "f0", int64 "f1")
// must survive the metadata round trip, first as a dense union and then as a
// sparse one.
TEST_F(TestSchemaMessage, UnionType) {
  const std::vector<std::shared_ptr<Field>> members = {
      std::make_shared<Field>("f0", TypePtr(new Int32Type())),
      std::make_shared<Field>("f1", TypePtr(new Int64Type()))};
  std::vector<uint8_t> type_ids = {};  // TODO(pcm): Implement typeIds

  for (UnionMode mode : {UnionMode::DENSE, UnionMode::SPARSE}) {
    TypePtr union_type(new UnionType(members, type_ids, mode));
    std::shared_ptr<Field> union_field = std::make_shared<Field>("f", union_type);
    Schema schema({union_field});
    CheckRoundtrip(&schema);
  }
}

class TestFileFooter : public ::testing::Test {
public:
void SetUp() {}
Expand Down
32 changes: 30 additions & 2 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/ipc/Message_generated.h"
#include "arrow/schema.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/buffer.h"
#include "arrow/util/status.h"

Expand Down Expand Up @@ -119,8 +120,20 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
case flatbuf::Type_Struct_:
*out = std::make_shared<StructType>(children);
return Status::OK();
case flatbuf::Type_Union:
return Status::NotImplemented("Type is not implemented");
case flatbuf::Type_Union: {
std::vector<uint8_t> type_ids = {}; // TODO(pcm): Implement typeIds
auto union_data = static_cast<const flatbuf::Union*>(type_data);
UnionMode mode;
if (union_data->mode() == flatbuf::UnionMode_Sparse) {
mode = UnionMode::SPARSE;
} else if (union_data->mode() == flatbuf::UnionMode_Dense) {
mode = UnionMode::DENSE;
} else {
return Status::Invalid("Unrecognized UnionMode");
}
*out = std::make_shared<UnionType>(children, type_ids, mode);
}
return Status::OK();
default:
return Status::Invalid("Unrecognized type");
}
Expand Down Expand Up @@ -158,6 +171,18 @@ static Status StructToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type
return Status::OK();
}

// Serializes a union type into flatbuffer metadata.
//
// Every child field of the union is converted with FieldToFlatbuffer and
// appended to `out_children` in child order; the Union table itself is written
// through `offset`.
//
// fbb:          flatbuffer builder that receives the serialized data
// type:         the type to serialize; must be a UnionType
// out_children: receives one FieldOffset per child of the union
// offset:       receives the offset of the serialized Union table
//
// Returns Status::Invalid when `type` is not a UnionType, the first failing
// status reported by FieldToFlatbuffer, and Status::OK otherwise.
static Status UnionToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
    std::vector<FieldOffset>* out_children, Offset* offset) {
  const auto union_type = std::dynamic_pointer_cast<UnionType>(type);
  if (!union_type) { return Status::Invalid("Expected a union type"); }

  const int num_children = union_type->num_children();
  for (int i = 0; i < num_children; ++i) {
    FieldOffset field;
    RETURN_NOT_OK(FieldToFlatbuffer(fbb, union_type->child(i), &field));
    out_children->push_back(field);
  }

  // Write the mode explicitly. Calling CreateUnion without it stores the
  // schema's default mode, so a dense union would not be read back as dense.
  const auto fb_mode = (union_type->mode == UnionMode::DENSE)
                           ? flatbuf::UnionMode_Dense
                           : flatbuf::UnionMode_Sparse;
  *offset = flatbuf::CreateUnion(fbb, fb_mode).Union();
  return Status::OK();
}

#define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \
*out_type = flatbuf::Type_Int; \
*offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \
Expand Down Expand Up @@ -208,6 +233,9 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
case Type::STRUCT:
*out_type = flatbuf::Type_Struct_;
return StructToFlatbuffer(fbb, type, children, offset);
case Type::UNION:
*out_type = flatbuf::Type_Union;
return UnionToFlatbuffer(fbb, type, children, offset);
default:
*out_type = flatbuf::Type_NONE; // Make clang-tidy happy
std::stringstream ss;
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ std::string UnionType::ToString() const {
return s.str();
}

// Compares this union type against `other`.
//
// Returns true only when the base DataType comparison succeeds, `other` is a
// UnionType, the `type_id` members match, and both `type_ids` vectors hold
// the same elements in the same order (which implies equal sizes).
//
// NOTE(review): `mode` is not part of the comparison, so a dense and a sparse
// union with identical children compare equal -- confirm this is intended.
bool UnionType::Equals(const DataType* other) const {
  if (!DataType::Equals(other)) { return false; }

  const UnionType* union_type = dynamic_cast<const UnionType*>(other);
  if (union_type == nullptr) { return false; }

  // Whole-vector comparison checks the sizes first. The three-iterator form
  // of std::equal would read past the end of the other vector when it is
  // shorter, and would accept a mere prefix match when it is longer.
  return type_id == union_type->type_id && type_ids == union_type->type_ids;
}

int NullType::bit_width() const {
return 0;
}
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,11 @@ struct ARROW_EXPORT UnionType : public DataType {
static std::string name() { return "union"; }
Status Accept(TypeVisitor* visitor) const override;

bool Equals(const DataType* other) const override;
bool Equals(const std::shared_ptr<DataType>& other) const {
return Equals(other.get());
}

UnionMode mode;
std::vector<uint8_t> type_ids;
};
Expand Down
47 changes: 46 additions & 1 deletion cpp/src/arrow/types/union.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,50 @@
#include <vector>

#include "arrow/type.h"
#include "arrow/util/status.h"

namespace arrow {} // namespace arrow
namespace arrow {

// Whole-array equality.
//
// Returns true when `arr` is this very object. Returns false when `arr` is
// null or when its type enum, length or null count differs from this array's.
// Otherwise the result is that of RangeEquals over [0, length_).
bool UnionArray::Equals(const std::shared_ptr<Array>& arr) const {
  if (this == arr.get()) { return true; }
  if (!arr) { return false; }
  if (this->type_enum() != arr->type_enum()) { return false; }
  // Arrays of different lengths can never be equal; without this check an
  // array would compare equal to a longer one that merely starts the same way.
  if (length_ != arr->length()) { return false; }
  if (null_count_ != arr->null_count()) { return false; }
  return RangeEquals(0, length_, 0, arr);
}

bool UnionArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (Type::UNION != arr->type_enum()) { return false; }
const auto other = static_cast<UnionArray*>(arr.get());

// TODO(pcm): Handle sparse case here

int32_t i = start_idx;
int32_t o_i = other_start_idx;
for (size_t c = 0; c < other->children().size(); ++c) {
for (int32_t e = 0; e < other->children()[c]->length(); ++e) {
if (!children()[c]->RangeEquals(e, e + 1, e, other->children()[c])) { // FIXME(pcm): fix this
return false;
}
i += 1;
o_i += 1;
if (i >= end_idx) {
return true;
}
}
}
return false; // to make the compiler happy
}

// Placeholder validation: no checks are performed yet and success is always
// reported.
Status UnionArray::Validate() const {
  // TODO(pcm): what to do here?
  Status result = Status::OK();
  return result;
}

// Visitor entry point: passes this array to the visitor's Visit method and
// returns the status it reports.
Status UnionArray::Accept(ArrayVisitor* visitor) const {
  const UnionArray& self = *this;
  return visitor->Visit(self);
}

} // namespace arrow
38 changes: 29 additions & 9 deletions cpp/src/arrow/types/union.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,45 @@

#include "arrow/array.h"
#include "arrow/type.h"
#include "arrow/types/primitive.h"

namespace arrow {

class Buffer;

class UnionArray : public Array {
class ARROW_EXPORT UnionArray : public Array {
public:
UnionArray(const TypePtr& type, int32_t length, std::vector<ArrayPtr>& children,
std::shared_ptr<Buffer> types, std::shared_ptr<Buffer> offset_buf,
int32_t null_count = 0, std::shared_ptr<Buffer> null_bitmap = nullptr)
: Array(type, length, null_count, null_bitmap), types_(types) {
type_ = type;
children_ = children;
offset_buf_ = offset_buf;
}

const std::shared_ptr<Buffer>& types() const { return types_; }

const std::vector<ArrayPtr>& children() const { return children_; }

const std::shared_ptr<Buffer>& offset_buf() const { return offset_buf_; }

Status Validate() const override;

Status Accept(ArrayVisitor* visitor) const override;

bool Equals(const std::shared_ptr<Array>& arr) const override;
bool RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const override;

ArrayPtr child(int32_t index) const { return children_[index]; }
protected:
// The data are types encoded as int16
Buffer* types_;
std::shared_ptr<Buffer> types_;
std::vector<std::shared_ptr<Array>> children_;
std::shared_ptr<Buffer> offset_buf_;
};

// Subclass of UnionArray named for the dense union mode.
// NOTE(review): in the surrounding diff these lines appear to belong to the
// previous version of the header -- confirm whether the class still exists.
class DenseUnionArray : public UnionArray {
protected:
// Raw, non-owning pointer to an offsets buffer. It carries the same name as
// the offset_buf_ member declared in UnionArray.
Buffer* offset_buf_;
};

// Subclass of UnionArray named for the sparse union mode; it declares no
// members of its own.
// NOTE(review): in the surrounding diff this line appears to belong to the
// previous version of the header -- confirm whether the class still exists.
class SparseUnionArray : public UnionArray {};

} // namespace arrow

#endif // ARROW_TYPES_UNION_H