Skip to content

Commit

Permalink
feat(c/driver/postgresql): Basic Support for Writing LIST types
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Jul 2, 2024
1 parent ce896df commit 5cb0702
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 7 deletions.
48 changes: 48 additions & 0 deletions c/driver/postgresql/copy/postgres_copy_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,54 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBinary) {
}
}

TEST(PostgresCopyUtilsTest, PostgresCopyWriteArray) {
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
struct ArrowError na_error;

ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK);
ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK);

ASSERT_EQ(ArrowSchemaInitFromType(schema->children[0], NANOARROW_TYPE_LIST),
NANOARROW_OK);
ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK);
// ASSERT_EQ(ArrowSchemaAllocateChildren(schema->children[0], 1), NANOARROW_OK);
ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT32),
NANOARROW_OK);

ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, nullptr), NANOARROW_OK);
ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK);

ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -123), NANOARROW_OK);
ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -1), NANOARROW_OK);
ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK);
ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK);

ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 0), NANOARROW_OK);
ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 1), NANOARROW_OK);
ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 123), NANOARROW_OK);
ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK);
ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK);

ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK);
ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK);

ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK);

PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

const struct ArrowBuffer buf = tester.WriteBuffer();
// The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
constexpr size_t buf_size = sizeof(kTestPgCopyIntegerArray) - 2;
ASSERT_EQ(buf.size_bytes, buf_size);
for (size_t i = 0; i < buf_size; i++) {
ASSERT_EQ(buf.data[i], kTestPgCopyIntegerArray[i]) << "failure at index " << i;
}
}

TEST(PostgresCopyUtilsTest, PostgresCopyWriteMultiBatch) {
// Regression test for https://github.com/apache/arrow-adbc/issues/1310
adbc_validation::Handle<struct ArrowSchema> schema;
Expand Down
125 changes: 118 additions & 7 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ class PostgresCopyFieldWriter {
class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
public:
void AppendChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
int64_t child_i = static_cast<int64_t>(children_.size());
children_.push_back(std::move(child));
children_[child_i]->Init(array_view_->children[child_i]);
}

ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
Expand Down Expand Up @@ -437,6 +435,56 @@ class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter {
}
};

class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter {
public:
explicit PostgresCopyListFieldWriter(uint32_t child_oid) : child_oid_{child_oid} {}

void InitChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
child_ = std::move(child);
}

ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
if (index >= array_view_->length) {
return ENODATA;
}

constexpr int32_t ndim = 1;
constexpr int32_t has_null_flags = 0;

const int32_t start = ArrowArrayViewListChildOffset(array_view_, index);
const int32_t end = ArrowArrayViewListChildOffset(array_view_, index + 1);
const int32_t dim = end - start;
constexpr int32_t lb = 1;

// TODO: this works for primitive types where we can calculate the buffer size
// in advance for varying types we likely need to create a separate buffer first
const int32_t child_record_size =
array_view_->children[0]->layout.element_size_bits[1] / 8;
const int32_t field_size_bytes =
sizeof(ndim) + sizeof(has_null_flags) + sizeof(child_oid_) + sizeof(dim) +
sizeof(lb)
// for each primitive record we send int32_t nbytes + the value itself
+ sizeof(int32_t) * dim + child_record_size * dim;

NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, ndim, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, has_null_flags, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<uint32_t>(buffer, child_oid_, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, dim, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, lb, error));

for (auto i = start; i < end; ++i) {
NANOARROW_RETURN_NOT_OK(child_->Write(buffer, i, error));
}

return ADBC_STATUS_OK;
}

private:
std::unique_ptr<PostgresCopyFieldWriter> child_;
const uint32_t child_oid_;
};

template <enum ArrowTimeUnit TU>
class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
public:
Expand Down Expand Up @@ -495,98 +543,116 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
};

static inline ArrowErrorCode MakeCopyFieldWriter(
struct ArrowSchema* schema, std::unique_ptr<PostgresCopyFieldWriter>* out,
ArrowError* error) {
struct ArrowSchema* schema, struct ArrowArrayView* array_view,
std::unique_ptr<PostgresCopyFieldWriter>* out, ArrowError* error) {
struct ArrowSchemaView schema_view;
NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));

switch (schema_view.type) {
case NANOARROW_TYPE_BOOL:
*out = std::make_unique<PostgresCopyBooleanFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_INT16:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int16_t>>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_INT32:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int32_t>>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_INT64:
*out = std::make_unique<PostgresCopyNetworkEndianFieldWriter<int64_t>>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_DATE32: {
constexpr int32_t kPostgresDateEpoch = 10957;
*out = std::make_unique<
PostgresCopyNetworkEndianFieldWriter<int32_t, kPostgresDateEpoch>>();
out->get()->Init(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_FLOAT:
*out = std::make_unique<PostgresCopyFloatFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_DOUBLE:
*out = std::make_unique<PostgresCopyDoubleFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_DECIMAL128: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = std::make_unique<PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL128>>(
precision, scale);
out->get()->Init(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DECIMAL256: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = std::make_unique<PostgresCopyNumericFieldWriter<NANOARROW_TYPE_DECIMAL256>>(
precision, scale);
out->get()->Init(array_view);
return NANOARROW_OK;
}
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
*out = std::make_unique<PostgresCopyBinaryFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_TIMESTAMP: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_NANO:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>>();
out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>>();
out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>>();
out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_SECOND:
*out = std::make_unique<
PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>>();
out->get()->Init(array_view);
break;
}
return NANOARROW_OK;
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
*out = std::make_unique<PostgresCopyIntervalFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_DURATION: {
switch (schema_view.time_unit) {
case NANOARROW_TIME_UNIT_SECOND:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_SECOND>>();
out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_MILLI:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MILLI>>();
out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_MICRO:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_MICRO>>();

out->get()->Init(array_view);
break;
case NANOARROW_TIME_UNIT_NANO:
*out = std::make_unique<
PostgresCopyDurationFieldWriter<NANOARROW_TIME_UNIT_NANO>>();
out->get()->Init(array_view);
break;
}
return NANOARROW_OK;
Expand All @@ -601,10 +667,55 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_LARGE_STRING:
*out = std::make_unique<PostgresCopyBinaryDictFieldWriter>();
out->get()->Init(array_view);
return NANOARROW_OK;
default:
break;
}
break;
}
case NANOARROW_TYPE_LIST: {
// For now our implementation only supports primitive children types
// See PostgresCopyListFieldWriter::Write for limtiations
struct ArrowSchemaView child_schema_view;
NANOARROW_RETURN_NOT_OK(
ArrowSchemaViewInit(&child_schema_view, schema->children[0], error));
switch (child_schema_view.type) {
case NANOARROW_TYPE_INT8:
case NANOARROW_TYPE_INT16:
case NANOARROW_TYPE_INT32:
case NANOARROW_TYPE_INT64:
case NANOARROW_TYPE_UINT8:
case NANOARROW_TYPE_UINT16:
case NANOARROW_TYPE_UINT32:
case NANOARROW_TYPE_UINT64:
case NANOARROW_TYPE_FLOAT:
case NANOARROW_TYPE_DOUBLE: {
// TODO: likely need to make type_resolver available here
// PostgresTypeResolver resolver;
// PostgresType child_type;
// NANOARROW_RETURN_NOT_OK(PostgresType::FromSchema(resolver,
// schema->children[0], &child_type, error));
constexpr uint32_t child_oid = 23; // TODO: don't hard-code

auto list_writer = std::make_unique<PostgresCopyListFieldWriter>(child_oid);
list_writer->Init(array_view);

std::unique_ptr<PostgresCopyFieldWriter> child_writer;
NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(
schema->children[0], array_view->children[0], &child_writer, error));

list_writer->InitChild(std::move(child_writer));

*out = std::move(list_writer);
return NANOARROW_OK;
}
default:
ArrowErrorSet(
error, "COPY Writer not implemented for list types with child type of %d",
child_schema_view.type);
return EINVAL;
}
}
default:
break;
Expand Down Expand Up @@ -658,8 +769,8 @@ class PostgresCopyStreamWriter {

for (int64_t i = 0; i < schema_->n_children; i++) {
std::unique_ptr<PostgresCopyFieldWriter> child_writer;
NANOARROW_RETURN_NOT_OK(
MakeCopyFieldWriter(schema_->children[i], &child_writer, error));
NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(
schema_->children[i], array_view_->children[i], &child_writer, error));
root_writer_.AppendChild(std::move(child_writer));
}

Expand Down

0 comments on commit 5cb0702

Please sign in to comment.