Skip to content

Commit

Permalink
implement dense unions
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz committed Nov 19, 2016
1 parent ed6ec3b commit d881d71
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 12 deletions.
9 changes: 9 additions & 0 deletions cpp/src/arrow/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ class ARROW_EXPORT ArrayBuilder {
DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
};

class ARROW_EXPORT NullArrayBuilder : public ArrayBuilder {
public:
explicit NullArrayBuilder(MemoryPool* pool, const TypePtr& type) : ArrayBuilder(pool, type) {}
virtual ~NullArrayBuilder() {};
Status Finish(std::shared_ptr<Array>* out) override {
return Status::OK();
}
};

} // namespace arrow

#endif // ARROW_BUILDER_H_
24 changes: 24 additions & 0 deletions cpp/src/arrow/ipc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/types/primitive.h"
#include "arrow/types/string.h"
#include "arrow/types/struct.h"
#include "arrow/types/union.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/buffer.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -115,6 +116,13 @@ Status VisitArray(const Array* arr, std::vector<flatbuf::FieldNode>* field_nodes
RETURN_NOT_OK(
VisitArray(field.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else if (arr->type_enum() == Type::UNION) {
const auto union_arr = static_cast<const UnionArray*>(arr);
buffers->push_back(union_arr->types());
buffers->push_back(union_arr->offset_buf());
for (auto& child_arr : union_arr->children()) {
RETURN_NOT_OK(VisitArray(child_arr.get(), field_nodes, buffers, max_recursion_depth - 1));
}
} else {
return Status::NotImplemented("Unrecognized type");
}
Expand Down Expand Up @@ -363,6 +371,22 @@ class RecordBatchReader::RecordBatchReaderImpl {
out->reset(new StructArray(
type, field_meta.length, fields, field_meta.null_count, null_bitmap));
return Status::OK();
} else if (type->type == Type::UNION) {
std::shared_ptr<Buffer> types;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &types));
std::shared_ptr<Buffer> offset_buf;
RETURN_NOT_OK(GetBuffer(buffer_index_++, &offset_buf));
auto union_type = std::dynamic_pointer_cast<UnionType>(type);
const int num_children = union_type->num_children();
std::vector<ArrayPtr> results;
for (int child_idx = 0; child_idx < num_children; ++child_idx) {
std::shared_ptr<Array> result;
RETURN_NOT_OK(NextArray(union_type->child(child_idx), max_recursion_depth - 1, &result));
results.push_back(result);
}
out->reset(new UnionArray(
type, field_meta.length, results, types, offset_buf, field_meta.null_count, null_bitmap));
return Status::OK();
}

return Status::NotImplemented("Non-primitive types not complete yet");
Expand Down
17 changes: 17 additions & 0 deletions cpp/src/arrow/ipc/ipc-metadata-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/schema.h"
#include "arrow/test-util.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/status.h"

namespace arrow {
Expand Down Expand Up @@ -97,6 +98,22 @@ TEST_F(TestSchemaMessage, NestedFields) {
CheckRoundtrip(&schema);
}

TEST_F(TestSchemaMessage, UnionType) {
auto f0 = std::make_shared<Field>("f0", TypePtr(new Int32Type()));
auto f1 = std::make_shared<Field>("f1", TypePtr(new Int64Type()));
std::vector<uint8_t> type_ids = {}; // TODO(pcm): Implement typeIds
auto ud = TypePtr(new UnionType(std::vector<std::shared_ptr<Field>>({f0, f1}),
type_ids, UnionMode::DENSE));
auto fd = std::make_shared<Field>("f", ud);
Schema schema_dense({fd});
CheckRoundtrip(&schema_dense);
auto us = TypePtr(new UnionType(std::vector<std::shared_ptr<Field>>({f0, f1}),
type_ids, UnionMode::SPARSE));
auto fs = std::make_shared<Field>("f", us);
Schema schema_sparse({fs});
CheckRoundtrip(&schema_sparse);
}

class TestFileFooter : public ::testing::Test {
public:
void SetUp() {}
Expand Down
32 changes: 30 additions & 2 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/ipc/Message_generated.h"
#include "arrow/schema.h"
#include "arrow/type.h"
#include "arrow/types/union.h"
#include "arrow/util/buffer.h"
#include "arrow/util/status.h"

Expand Down Expand Up @@ -119,8 +120,20 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
case flatbuf::Type_Struct_:
*out = std::make_shared<StructType>(children);
return Status::OK();
case flatbuf::Type_Union:
return Status::NotImplemented("Type is not implemented");
case flatbuf::Type_Union: {
std::vector<uint8_t> type_ids = {}; // TODO(pcm): Implement typeIds
auto union_data = static_cast<const flatbuf::Union*>(type_data);
UnionMode mode;
if (union_data->mode() == flatbuf::UnionMode_Sparse) {
mode = UnionMode::SPARSE;
} else if (union_data->mode() == flatbuf::UnionMode_Dense) {
mode = UnionMode::DENSE;
} else {
return Status::Invalid("Unrecognized UnionMode");
}
*out = std::make_shared<UnionType>(children, type_ids, mode);
}
return Status::OK();
default:
return Status::Invalid("Unrecognized type");
}
Expand Down Expand Up @@ -158,6 +171,18 @@ static Status StructToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type
return Status::OK();
}

static Status UnionToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
std::vector<FieldOffset>* out_children, Offset* offset) {
auto union_type = std::dynamic_pointer_cast<UnionType>(type);
FieldOffset field;
for (int i = 0; i < union_type->num_children(); ++i) {
RETURN_NOT_OK(FieldToFlatbuffer(fbb, union_type->child(i), &field));
out_children->push_back(field);
}
*offset = flatbuf::CreateUnion(fbb).Union();
return Status::OK();
}

#define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \
*out_type = flatbuf::Type_Int; \
*offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \
Expand Down Expand Up @@ -208,6 +233,9 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
case Type::STRUCT:
*out_type = flatbuf::Type_Struct_;
return StructToFlatbuffer(fbb, type, children, offset);
case Type::UNION:
*out_type = flatbuf::Type_Union;
return UnionToFlatbuffer(fbb, type, children, offset);
default:
*out_type = flatbuf::Type_NONE; // Make clang-tidy happy
std::stringstream ss;
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ std::string UnionType::ToString() const {
return s.str();
}

bool UnionType::Equals(const DataType* other) const {
if (!DataType::Equals(other)) {
return false;
}
const UnionType *union_type = dynamic_cast<const UnionType*>(other);
return union_type && type_id == union_type->type_id
&& std::equal(type_ids.begin(), type_ids.end(),
union_type->type_ids.begin());
}

int NullType::bit_width() const {
return 0;
}
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,11 @@ struct ARROW_EXPORT UnionType : public DataType {
static std::string name() { return "union"; }
Status Accept(TypeVisitor* visitor) const override;

bool Equals(const DataType* other) const override;
bool Equals(const std::shared_ptr<DataType>& other) const {
return Equals(other.get());
}

UnionMode mode;
std::vector<uint8_t> type_ids;
};
Expand Down
47 changes: 46 additions & 1 deletion cpp/src/arrow/types/union.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,50 @@
#include <vector>

#include "arrow/type.h"
#include "arrow/util/status.h"

namespace arrow {} // namespace arrow
namespace arrow {

bool UnionArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (!arr) { return false; }
if (this->type_enum() != arr->type_enum()) { return false; }
if (null_count_ != arr->null_count()) { return false; }
return RangeEquals(0, length_, 0, arr);
}

bool UnionArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (Type::UNION != arr->type_enum()) { return false; }
const auto other = static_cast<UnionArray*>(arr.get());

// TODO(pcm): Handle sparse case here

int32_t i = start_idx;
int32_t o_i = other_start_idx;
for (size_t c = 0; c < other->children().size(); ++c) {
for (int32_t e = 0; e < other->children()[c]->length(); ++e) {
if (!children()[c]->RangeEquals(e, e + 1, e, other->children()[c])) { // FIXME(pcm): fix this
return false;
}
i += 1;
o_i += 1;
if (i >= end_idx) {
return true;
}
}
}
return false; // to make the compiler happy
}

Status UnionArray::Validate() const {
// TODO(pcm): what to do here?
return Status::OK();
}

Status UnionArray::Accept(ArrayVisitor* visitor) const {
return visitor->Visit(*this);
}

} // namespace arrow
39 changes: 30 additions & 9 deletions cpp/src/arrow/types/union.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,46 @@

#include "arrow/array.h"
#include "arrow/type.h"
#include "arrow/types/collection.h"
#include "arrow/types/primitive.h"

namespace arrow {

class Buffer;

class UnionArray : public Array {
class ARROW_EXPORT UnionArray : public Array {
public:
UnionArray(const TypePtr& type, int32_t length, std::vector<ArrayPtr>& children,
std::shared_ptr<Buffer> types, std::shared_ptr<Buffer> offset_buf,
int32_t null_count = 0, std::shared_ptr<Buffer> null_bitmap = nullptr)
: Array(type, length, null_count, null_bitmap), types_(types) {
type_ = type;
children_ = children;
offset_buf_ = offset_buf;
}

const std::shared_ptr<Buffer>& types() const { return types_; }

const std::vector<ArrayPtr>& children() const { return children_; }

const std::shared_ptr<Buffer>& offset_buf() const { return offset_buf_; }

Status Validate() const override;

Status Accept(ArrayVisitor* visitor) const override;

bool Equals(const std::shared_ptr<Array>& arr) const override;
bool RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
const std::shared_ptr<Array>& arr) const override;

ArrayPtr child(int32_t index) const { return children_[index]; }
protected:
// The data are types encoded as int16
Buffer* types_;
std::shared_ptr<Buffer> types_;
std::vector<std::shared_ptr<Array>> children_;
std::shared_ptr<Buffer> offset_buf_;
};

class DenseUnionArray : public UnionArray {
protected:
Buffer* offset_buf_;
};

class SparseUnionArray : public UnionArray {};

} // namespace arrow

#endif // ARROW_TYPES_UNION_H

0 comments on commit d881d71

Please sign in to comment.