Skip to content

Commit

Permalink
implement dense unions
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz committed Nov 18, 2016
1 parent 48f9780 commit 16c6fdc
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 6 deletions.
9 changes: 9 additions & 0 deletions cpp/src/arrow/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ class ARROW_EXPORT ArrayBuilder {
DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
};

class ARROW_EXPORT NullArrayBuilder : public ArrayBuilder {
public:
explicit NullArrayBuilder(MemoryPool* pool, const TypePtr& type) : ArrayBuilder(pool, type) {}
virtual ~NullArrayBuilder() {};
Status Finish(std::shared_ptr<Array>* out) override {
return Status::OK();
}
};

} // namespace arrow

#endif // ARROW_BUILDER_H_
25 changes: 25 additions & 0 deletions cpp/src/arrow/ipc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/types/primitive.h"
#include "arrow/types/string.h"
#include "arrow/types/struct.h"
#include "arrow/types/union.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/buffer.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -115,6 +116,13 @@ Status VisitArray(const Array* arr, std::vector<flatbuf::FieldNode>* field_nodes
RETURN_NOT_OK(
VisitArray(field.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else if (arr->type_enum() == Type::DENSE_UNION) {
const auto union_arr = static_cast<const DenseUnionArray*>(arr);
buffers->push_back(union_arr->types());
buffers->push_back(union_arr->offset_buf());
for (auto& child_arr : union_arr->children()) {
RETURN_NOT_OK(VisitArray(child_arr.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else {
return Status::NotImplemented("Unrecognized type");
}
Expand Down Expand Up @@ -363,6 +371,23 @@ class RecordBatchReader::RecordBatchReaderImpl {
out->reset(new StructArray(
type, field_meta.length, fields, field_meta.null_count, null_bitmap));
return Status::OK();
} else if (type->type == Type::DENSE_UNION) {
std::shared_ptr<Buffer> types;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &types));
std::shared_ptr<Buffer> offset_buf;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &offset_buf));
auto dense_union_type = std::dynamic_pointer_cast<DenseUnionType>(type);
const int num_children = dense_union_type->num_children();
std::vector<ArrayPtr> results;
for (int child_idx = 0; child_idx < num_children; ++child_idx) {
std::shared_ptr<Array> result;
auto child_field = std::make_shared<Field>(std::string(""), dense_union_type->child(child_idx), false);
RETURN_NOT_OK(NextArray(child_field.get(), max_recursion_depth - 1, &result));
results.push_back(result);
}
out->reset(new DenseUnionArray(
type, field_meta.length, results, types, offset_buf, field_meta.null_count, null_bitmap));
return Status::OK();
}

return Status::NotImplemented("Non-primitive types not complete yet");
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/ipc/ipc-metadata-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/schema.h"
#include "arrow/test-util.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/status.h"

namespace arrow {
Expand Down Expand Up @@ -97,6 +98,15 @@ TEST_F(TestSchemaMessage, NestedFields) {
CheckRoundtrip(&schema);
}

TEST_F(TestSchemaMessage, DenseUnion) {
auto t0 = TypePtr(new Int32Type());
auto t1 = TypePtr(new Int64Type());
auto u = TypePtr(new DenseUnionType(std::vector<TypePtr>({t0, t1})));
auto f = std::make_shared<Field>("f", u);
Schema schema({f});
CheckRoundtrip(&schema);
}

class TestFileFooter : public ::testing::Test {
public:
void SetUp() {}
Expand Down
27 changes: 25 additions & 2 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/ipc/Message_generated.h"
#include "arrow/schema.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/buffer.h"
#include "arrow/util/status.h"

Expand Down Expand Up @@ -119,8 +120,14 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
case flatbuf::Type_Struct_:
*out = std::make_shared<StructType>(children);
return Status::OK();
case flatbuf::Type_Union:
return Status::NotImplemented("Type is not implemented");
case flatbuf::Type_Union: {
std::vector<TypePtr> child_types;
for (auto type : children) {
child_types.push_back(type->type);
}
*out = std::make_shared<DenseUnionType>(child_types); // TODO(pcm): SparseUnionType
}
return Status::OK();
default:
return Status::Invalid("Unrecognized type");
}
Expand Down Expand Up @@ -158,6 +165,19 @@ static Status StructToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type
return Status::OK();
}

static Status UnionToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
std::vector<FieldOffset>* out_children, Offset* offset) {
auto dense_union_type = std::dynamic_pointer_cast<DenseUnionType>(type);
FieldOffset field;
for (int i = 0; i < dense_union_type->num_children(); ++i) {
auto child_field = std::make_shared<Field>(std::string(""), dense_union_type->child(i), false);
RETURN_NOT_OK(FieldToFlatbuffer(fbb, child_field, &field));
out_children->push_back(field);
}
*offset = flatbuf::CreateUnion(fbb).Union();
return Status::OK();
}

#define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \
*out_type = flatbuf::Type_Int; \
*offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \
Expand Down Expand Up @@ -208,6 +228,9 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
case Type::STRUCT:
*out_type = flatbuf::Type_Struct_;
return StructToFlatbuffer(fbb, type, children, offset);
case Type::DENSE_UNION:
*out_type = flatbuf::Type_Union;
return UnionToFlatbuffer(fbb, type, children, offset);
default:
*out_type = flatbuf::Type_NONE; // Make clang-tidy happy
std::stringstream ss;
Expand Down
36 changes: 36 additions & 0 deletions cpp/src/arrow/types/union.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,45 @@
#include <vector>

#include "arrow/type.h"
#include "arrow/util/status.h"

namespace arrow {

bool DenseUnionArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (!arr) { return false; }
if (this->type_enum() != arr->type_enum()) { return false; }
if (null_count_ != arr->null_count()) { return false; }
return RangeEquals(0, length_, 0, arr);
}

bool DenseUnionArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (Type::DENSE_UNION != arr->type_enum()) { return false; }
const auto other = static_cast<DenseUnionArray*>(arr.get());

int32_t i = start_idx;
int32_t o_i = other_start_idx;
for (size_t c = 0; c < other->children().size(); ++c) {
for (int32_t e = 0; e < other->children()[c]->length(); ++e) {
if (!children()[c]->RangeEquals(e, e + 1, e, other->children()[c])) { // FIXME(pcm): fix this
return false;
}
i += 1;
o_i += 1;
if (i >= end_idx) {
return true;
}
}
}
return false; // to make the compiler happy
}

Status DenseUnionArray::Validate() const {
return Status::OK();
}

static inline std::string format_union(const std::vector<TypePtr>& child_types) {
std::stringstream s;
s << "union<";
Expand Down
39 changes: 35 additions & 4 deletions cpp/src/arrow/types/union.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
#include "arrow/array.h"
#include "arrow/type.h"
#include "arrow/types/collection.h"
#include "arrow/types/primitive.h"

namespace arrow {

class Buffer;

struct DenseUnionType : public CollectionType<Type::DENSE_UNION> {
struct ARROW_EXPORT DenseUnionType : public CollectionType<Type::DENSE_UNION> {
typedef CollectionType<Type::DENSE_UNION> Base;

explicit DenseUnionType(const std::vector<TypePtr>& child_types) : Base() {
Expand All @@ -51,15 +52,45 @@ struct SparseUnionType : public CollectionType<Type::SPARSE_UNION> {
};

class UnionArray : public Array {
public:
UnionArray(const TypePtr& type, int32_t length, std::vector<ArrayPtr>& children,
std::shared_ptr<Buffer> types,
int32_t null_count = 0, std::shared_ptr<Buffer> null_bitmap = nullptr)
: Array(type, length, null_count, null_bitmap), types_(types) {
type_ = type;
children_ = children;
}

const std::shared_ptr<Buffer>& types() const { return types_; }

const std::vector<ArrayPtr>& children() const { return children_; }

ArrayPtr child(int32_t index) const { return children_[index]; }
protected:
// The data are types encoded as int16
Buffer* types_;
std::shared_ptr<Buffer> types_;
std::vector<std::shared_ptr<Array>> children_;
};

class DenseUnionArray : public UnionArray {
class ARROW_EXPORT DenseUnionArray : public UnionArray {
public:
DenseUnionArray(const TypePtr& type, int32_t length, std::vector<ArrayPtr>& children,
std::shared_ptr<Buffer> types, std::shared_ptr<Buffer> offset_buf,
int32_t null_count = 0, std::shared_ptr<Buffer> null_bitmap = nullptr)
: UnionArray(type, length, children, types, null_count, null_bitmap) {
offset_buf_ = offset_buf;
}

const std::shared_ptr<Buffer>& offset_buf() const { return offset_buf_; }

Status Validate() const override;

bool Equals(const std::shared_ptr<Array>& arr) const override;
bool RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const override;

protected:
Buffer* offset_buf_;
std::shared_ptr<Buffer> offset_buf_;
};

class SparseUnionArray : public UnionArray {};
Expand Down

0 comments on commit 16c6fdc

Please sign in to comment.