Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI: Pass names of children struct columns to native Arrow IPC writer [skip ci] #7598

Merged
merged 16 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ArrowIPCWriterOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,64 @@ public Builder withCallback(DoneOnGpu callback) {
return this;
}

/**
* Add the name(s) for nullable column(s).
*
* Please note the column names of the nested struct columns should be flattened in sequence.
* For examples,
* <pre>
* A table with an int column and a struct column:
* ["int_col", "struct_col":{"field_1", "field_2"}]
* output:
* ["int_col", "struct_col", "field_1", "field_2"]
*
* A table with an int column and a list of non-nested type column:
* ["int_col", "list_col":[]]
* output:
* ["int_col", "list_col"]
*
* A table with an int column and a list of struct column:
* ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
* output:
* ["int_col", "list_struct_col", "field_1", "field_2"]
* </pre>
*
* @param columnNames The column names corresponding to the written table(s).
*/
@Override
public Builder withColumnNames(String... columnNames) {
return super.withColumnNames(columnNames);
}

/**
* Add the name(s) for non-nullable column(s).
*
* Please note the column names of the nested struct columns should be flattened in sequence.
* For examples,
* <pre>
* A table with an int column and a struct column:
* ["int_col", "struct_col":{"field_1", "field_2"}]
* output:
* ["int_col", "struct_col", "field_1", "field_2"]
*
* A table with an int column and a list of non-nested type column:
* ["int_col", "list_col":[]]
* output:
* ["int_col", "list_col"]
*
* A table with an int column and a list of struct column:
* ["int_col", "list_struct_col":[{"field_1", "field_2"}]]
* output:
* ["int_col", "list_struct_col", "field_1", "field_2"]
* </pre>
*
* @param columnNames The column names corresponding to the written table(s).
*/
@Override
public Builder withNotNullableColumnNames(String... columnNames) {
return super.withNotNullableColumnNames(columnNames);
}

public ArrowIPCWriterOptions build() {
return new ArrowIPCWriterOptions(this);
}
Expand Down
63 changes: 57 additions & 6 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,15 @@ class native_arrow_ipc_writer_handle final {
const std::shared_ptr<arrow::io::OutputStream> &sink)
: initialized(false), column_names(col_names), file_name(""), sink(sink) {}

private:
bool initialized;
std::vector<std::string> column_names;
std::vector<cudf::column_metadata> columns_meta;
std::string file_name;
std::shared_ptr<arrow::io::OutputStream> sink;
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;

public:
void write(std::shared_ptr<arrow::Table> &arrow_tab, int64_t max_chunk) {
if (!initialized) {
if (!sink) {
Expand Down Expand Up @@ -245,6 +248,59 @@ class native_arrow_ipc_writer_handle final {
}
initialized = false;
}

std::vector<cudf::column_metadata> get_column_metadata(const cudf::table_view& tview) {
if (!column_names.empty() && columns_meta.empty()) {
// Rebuild the structure of column meta according to table schema.
// All the tables written by this writer should share the same schema,
// so build column metadata only once.
columns_meta.reserve(tview.num_columns());
size_t idx = 0;
for (auto itr = tview.begin(); itr < tview.end(); ++itr) {
// It should consume the column names only when a column is
// - type of struct, or
// - not a child.
columns_meta.push_back(build_one_column_meta(*itr, idx));
}
if (idx < column_names.size()) {
throw cudf::jni::jni_exception("Too many column names are provided.");
}
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved
return columns_meta;
}

private:
cudf::column_metadata build_one_column_meta(const cudf::column_view& cview, size_t& idx,
const bool consume_name = true) {
auto col_meta = cudf::column_metadata{};
if (consume_name) {
col_meta.name = get_column_name(idx++);
}
// Process children
if (cview.type().id() == cudf::type_id::LIST) {
// list type:
// - requires a stub metadata for offset column(index: 0).
// - does not require a name for the child column(index 1).
col_meta.children_meta = {{}, build_one_column_meta(cview.child(1), idx, false)};
} else if (cview.type().id() == cudf::type_id::STRUCT) {
// struct type always consumes the column names.
col_meta.children_meta.reserve(cview.num_children());
for (auto itr = cview.child_begin(); itr < cview.child_end(); ++itr) {
col_meta.children_meta.push_back(build_one_column_meta(*itr, idx));
}
} else if (cview.type().id() == cudf::type_id::DICTIONARY32) {
// not supported yet in JNI, nested type?
jlowe marked this conversation as resolved.
Show resolved Hide resolved
throw cudf::jni::jni_exception("Unsupported type 'DICTIONARY32'");
}
return col_meta;
}

std::string& get_column_name(const size_t idx) {
if (idx < 0 || idx >= column_names.size()) {
throw cudf::jni::jni_exception("Missing names for columns or nested struct columns");
}
return column_names[idx];
}
};

class jni_arrow_output_stream final : public arrow::io::OutputStream {
Expand Down Expand Up @@ -1245,12 +1301,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_convertCudfToArrowTable(JNIEnv
cudf::jni::auto_set_device(env);
std::unique_ptr<std::shared_ptr<arrow::Table>> result(
new std::shared_ptr<arrow::Table>(nullptr));
auto column_metadata = std::vector<cudf::column_metadata>{};
column_metadata.reserve(state->column_names.size());
std::transform(std::begin(state->column_names), std::end(state->column_names),
std::back_inserter(column_metadata),
[](auto const &column_name) { return cudf::column_metadata{column_name}; });
*result = cudf::to_arrow(*tview, column_metadata);
*result = cudf::to_arrow(*tview, state->get_column_metadata(*tview));
if (!result->get()) {
return 0;
}
Expand Down
46 changes: 36 additions & 10 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4056,15 +4056,38 @@ void testTableBasedFilter() {
}

private Table getExpectedFileTable() {
return new TestBuilder()
.column(true, false, false, true, false)
.column(5, 1, 0, 2, 7)
.column(new Byte[]{2, 3, 4, 5, 9})
.column(3l, 9l, 4l, 2l, 20l)
.column("this", "is", "a", "test", "string")
.column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
.column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d)
.build();
return getExpectedFileTable(false);
}

private Table getExpectedFileTable(boolean withNestedColumns) {
TestBuilder tb = new TestBuilder()
.column(true, false, false, true, false)
.column(5, 1, 0, 2, 7)
.column(new Byte[]{2, 3, 4, 5, 9})
.column(3l, 9l, 4l, 2l, 20l)
.column("this", "is", "a", "test", "string")
.column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
.column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d);
if (withNestedColumns) {
StructType nestedType = new StructType(true,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
tb.column(nestedType,
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), new HostColumnVector.StructData((List) null))
.column(new ListType(false, new BasicType(false, DType.INT32)),
Arrays.asList(1, 2),
Arrays.asList(3, 4),
Arrays.asList(5),
Arrays.asList(6, 7),
Arrays.asList(8, 9, 10))
.column(new ListType(false, nestedType),
Arrays.asList(struct(1, "k1"), struct(2, "k2"), struct(3, "k3")),
Arrays.asList(struct(4, "k4"), struct(5, "k5")),
Arrays.asList(struct(6, "k6")),
Arrays.asList(new HostColumnVector.StructData((List) null)),
Arrays.asList());
}
return tb.build();
}

private Table getExpectedFileTableWithDecimals() {
Expand Down Expand Up @@ -4272,10 +4295,13 @@ void testArrowIPCWriteToFileWithNamesAndMetadata() throws IOException {

@Test
void testArrowIPCWriteToBufferChunked() {
try (Table table0 = getExpectedFileTable();
try (Table table0 = getExpectedFileTable(true);
MyBufferConsumer consumer = new MyBufferConsumer()) {
ArrowIPCWriterOptions options = ArrowIPCWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
.withColumnNames("eighth", "eighth_id", "eighth_name")
.withColumnNames("ninth")
.withColumnNames("tenth", "child_id", "child_name")
.build();
try (TableWriter writer = Table.writeArrowIPCChunked(options, consumer)) {
writer.write(table0);
Expand Down