Datasets Java API: Support reading parquet into dictionary-encoded record batches (apache#73)
zhztheplayer authored Jun 18, 2020
1 parent 4d166c4 commit c1f2ec7
Showing 7 changed files with 221 additions and 48 deletions.
23 changes: 19 additions & 4 deletions cpp/src/arrow/dataset/file_parquet.cc
@@ -101,14 +101,29 @@ static parquet::ReaderProperties MakeReaderProperties(
  return properties;
}

void SetDictionaryColumns(const parquet::ParquetFileReader &reader,
                          parquet::ArrowReaderProperties &properties,
                          const std::unordered_set<std::string> &dict_columns) {
  if (dict_columns.empty()) {
    // default: dict-encode all columns
    int num_columns = reader.metadata()->num_columns();
    for (int i = 0; i < num_columns; i++) {
      properties.set_read_dictionary(i, true);
    }
    return;
  }
  for (const std::string& name : dict_columns) {
    auto column_index = reader.metadata()->schema()->ColumnIndex(name);
    properties.set_read_dictionary(column_index, true);
  }
}

static parquet::ArrowReaderProperties MakeArrowReaderProperties(
    const ParquetFileFormat& format, int64_t batch_size,
    const parquet::ParquetFileReader& reader) {
  parquet::ArrowReaderProperties properties(/* use_threads = */ false);
  for (const std::string& name : format.reader_options.dict_columns) {
    auto column_index = reader.metadata()->schema()->ColumnIndex(name);
    properties.set_read_dictionary(column_index, true);
  }
  std::unordered_set<std::string> dict_columns = format.reader_options.dict_columns;
  SetDictionaryColumns(reader, properties, dict_columns);
  properties.set_batch_size(batch_size);
  return properties;
}
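The hunk above flips the default: when reader_options.dict_columns is empty, every Parquet column is now read dictionary-encoded rather than none. For orientation, here is a minimal, self-contained Java sketch (not part of this commit, using only the stock Arrow Java vector API) of what dictionary encoding means: values are replaced by integer indices into a separate vector of distinct values.

import java.nio.charset.StandardCharsets;

import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

public class DictionaryEncodingSketch {
  public static void main(String[] args) {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         VarCharVector dictValues = new VarCharVector("cities-dict", allocator);
         VarCharVector data = new VarCharVector("cities", allocator)) {
      // The distinct values live in a separate dictionary vector.
      dictValues.allocateNew();
      dictValues.setSafe(0, "NY".getBytes(StandardCharsets.UTF_8));
      dictValues.setSafe(1, "SF".getBytes(StandardCharsets.UTF_8));
      dictValues.setValueCount(2);

      // The data vector repeats those values.
      data.allocateNew();
      data.setSafe(0, "NY".getBytes(StandardCharsets.UTF_8));
      data.setSafe(1, "NY".getBytes(StandardCharsets.UTF_8));
      data.setSafe(2, "SF".getBytes(StandardCharsets.UTF_8));
      data.setValueCount(3);

      // The id (0 here, arbitrary) is what links indices back to the dictionary.
      Dictionary dictionary = new Dictionary(dictValues,
          new DictionaryEncoding(0L, false, new ArrowType.Int(32, true)));
      try (ValueVector indices = DictionaryEncoder.encode(data, dictionary)) {
        // indices is an IntVector holding [0, 0, 1].
        for (int i = 0; i < indices.getValueCount(); i++) {
          System.out.println(indices.getObject(i));
        }
      }
    }
  }
}

A column read with set_read_dictionary enabled arrives in roughly this shape per batch: an index vector plus a shared dictionary of values.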
121 changes: 100 additions & 21 deletions cpp/src/jni/dataset/jni_wrapper.cpp
@@ -24,11 +24,11 @@
#include <arrow/util/iterator.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/message.h>
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/cast.h"
#include "arrow/compute/kernels/compare.h"
#include "jni/dataset/DTypes.pb.h"
#include "jni/dataset/concurrent_map.h"
#include <arrow/compute/kernel.h>
#include <arrow/compute/kernels/cast.h>
#include <arrow/compute/kernels/compare.h>
#include <jni/dataset/DTypes.pb.h>
#include <jni/dataset/concurrent_map.h>

#include "org_apache_arrow_dataset_file_JniWrapper.h"
#include "org_apache_arrow_dataset_jni_JniWrapper.h"
@@ -40,10 +40,12 @@ static jclass runtime_exception_class;
static jclass record_batch_handle_class;
static jclass record_batch_handle_field_class;
static jclass record_batch_handle_buffer_class;
static jclass dictionary_batch_handle_class;

static jmethodID record_batch_handle_constructor;
static jmethodID record_batch_handle_field_constructor;
static jmethodID record_batch_handle_buffer_constructor;
static jmethodID dictionary_batch_handle_constructor;

static jint JNI_VERSION = JNI_VERSION_1_6;

@@ -119,6 +121,15 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
      GetMethodID(env, record_batch_handle_class, "<init>",
                  "(J[Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle$Field;"
                  "[Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle$Buffer;)V");

  dictionary_batch_handle_class = CreateGlobalClassReference(
      env, "Lorg/apache/arrow/dataset/jni/NativeDictionaryBatchHandle;");

  dictionary_batch_handle_constructor =
      GetMethodID(env, dictionary_batch_handle_class, "<init>",
                  "(JJ[Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle$Field;"
                  "[Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle$Buffer;)V");

  record_batch_handle_field_constructor =
      GetMethodID(env, record_batch_handle_field_class, "<init>", "(JJ)V");
  record_batch_handle_buffer_constructor =
@@ -138,6 +149,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
  env->DeleteGlobalRef(record_batch_handle_class);
  env->DeleteGlobalRef(record_batch_handle_field_class);
  env->DeleteGlobalRef(record_batch_handle_buffer_class);
  env->DeleteGlobalRef(dictionary_batch_handle_class);

  dataset_factory_holder_.Clear();
  dataset_holder_.Clear();
@@ -219,6 +231,14 @@ std::vector<T> collect(JNIEnv* env, arrow::Iterator<T> itr) {
  return vector;
}

jobjectArray makeObjectArray(JNIEnv* env, jclass clazz, std::vector<jobject> args) {
  jobjectArray oa = env->NewObjectArray(args.size(), clazz, 0);
  for (size_t i = 0; i < args.size(); i++) {
    env->SetObjectArrayElement(oa, i, args.at(i));
  }
  return oa;
}

// FIXME: COPIED FROM intel/master on which this branch is not rebased yet
// FIXME:
// https://github.com/Intel-bigdata/arrow/blob/02502a4eb59834c2471dd629e77dbeed19559f68/cpp/src/jni/jni_common.h#L239-L254
@@ -512,19 +532,10 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_scan(
      std::move(record_batch_iterator)));  // move and propagate
}

/*
 * Class: org_apache_arrow_dataset_jni_JniWrapper
 * Method: nextRecordBatch
 * Signature: (J)Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle;
 */
JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
    JNIEnv* env, jobject, jlong iterator_id) {
  std::shared_ptr<arrow::RecordBatchIterator> itr = iterator_holder_.Lookup(iterator_id);

  JNI_ASSIGN_OR_THROW(std::shared_ptr<arrow::RecordBatch> record_batch, itr->Next())
  if (record_batch == nullptr) {
    return nullptr;  // stream ended
  }
template<typename HandleCreator>
jobject createJavaHandle(JNIEnv *env,
                         std::shared_ptr<arrow::RecordBatch> &record_batch,
                         HandleCreator handle_creator) {
  std::shared_ptr<arrow::Schema> schema = record_batch->schema();
  jobjectArray field_array =
      env->NewObjectArray(schema->num_fields(), record_batch_handle_field_class, nullptr);
@@ -561,12 +572,80 @@ JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
        buffer_holder_.Insert(buffer), data, size, capacity);
    env->SetObjectArrayElement(buffer_array, j, buffer_handle);
  }

  jobject ret = env->NewObject(record_batch_handle_class, record_batch_handle_constructor,
                               record_batch->num_rows(), field_array, buffer_array);
  int64_t num_rows = record_batch->num_rows();
  jobject ret = handle_creator(num_rows, field_array, buffer_array);
  return ret;
}

jobject createJavaRecordBatchHandle(JNIEnv *env,
                                    std::shared_ptr<arrow::RecordBatch> &record_batch) {
  auto handle_creator = [env] (int64_t num_rows, jobjectArray field_array,
                               jobjectArray buffer_array) {
    return env->NewObject(record_batch_handle_class, record_batch_handle_constructor,
                          num_rows, field_array, buffer_array);
  };
  return createJavaHandle(env, record_batch, handle_creator);
}

jobject createJavaDictionaryBatchHandle(JNIEnv *env, jlong id,
                                        std::shared_ptr<arrow::RecordBatch> &record_batch) {
  auto handle_creator = [env, id] (int64_t num_rows, jobjectArray field_array,
                                   jobjectArray buffer_array) {
    return env->NewObject(dictionary_batch_handle_class,
                          dictionary_batch_handle_constructor, id, num_rows,
                          field_array, buffer_array);
  };
  return createJavaHandle(env, record_batch, handle_creator);
}

/*
 * Class: org_apache_arrow_dataset_jni_JniWrapper
 * Method: nextRecordBatch
 * Signature: (J)[Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle;
 */
jobjectArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
    JNIEnv* env, jobject, jlong iterator_id) {
  std::shared_ptr<arrow::RecordBatchIterator> itr = iterator_holder_.Lookup(iterator_id);

  JNI_ASSIGN_OR_THROW(std::shared_ptr<arrow::RecordBatch> record_batch, itr->Next())
  if (record_batch == nullptr) {
    return nullptr;  // stream ended
  }
  std::vector<jobject> handles;
  jobject handle = createJavaRecordBatchHandle(env, record_batch);
  handles.push_back(handle);

  // dictionary batches
  int num_columns = record_batch->num_columns();
  long dict_id = 0;
  for (int i = 0; i < num_columns; i++) {
    // defer to Java dictionary batch rule: a single array per batch
    std::shared_ptr<arrow::Field> field = record_batch->schema()->field(i);
    std::shared_ptr<arrow::Array> data = record_batch->column(i);
    std::shared_ptr<arrow::DataType> type = field->type();
    if (type->id() == arrow::Type::DICTIONARY) {
      std::shared_ptr<arrow::DataType> value_type =
          arrow::internal::checked_cast<const arrow::DictionaryType &>(*type)
              .value_type();
      std::shared_ptr<arrow::DictionaryArray> dict_data =
          arrow::internal::checked_pointer_cast<arrow::DictionaryArray>(data);
      std::shared_ptr<arrow::Field> value_field =
          std::make_shared<arrow::Field>(field->name(), value_type);
      std::vector<std::shared_ptr<arrow::Field>> dict_batch_fields;
      dict_batch_fields.push_back(value_field);
      std::shared_ptr<arrow::Schema> dict_batch_schema =
          std::make_shared<arrow::Schema>(dict_batch_fields);
      std::vector<std::shared_ptr<arrow::Array>> dict_datum;
      dict_datum.push_back(dict_data->dictionary());
      std::shared_ptr<arrow::RecordBatch> dict_batch =
          arrow::RecordBatch::Make(dict_batch_schema, dict_data->length(), dict_datum);
      jobject dict_handle = createJavaDictionaryBatchHandle(env, dict_id++, dict_batch);
      handles.push_back(dict_handle);
    }
  }
  return makeObjectArray(env, record_batch_handle_class, handles);
}
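The loop above follows the Java convention noted in the comment, one array per dictionary batch: for each DICTIONARY column the shared values ship as their own batch keyed by an id, while the index vector stays behind in the data batch. Because dict_id restarts at zero on each call, a given column keeps the same id from batch to batch. On the Java side the two halves are reunited by that id; below is a hedged sketch of that step using the stock Arrow Java dictionary API (the helper name is invented for illustration, not part of this commit).

import java.util.Map;

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

final class DictionaryDecodeSketch {
  /**
   * Materialize the values for one possibly dictionary-encoded column.
   * Plain columns pass through untouched.
   */
  static ValueVector decodeColumn(FieldVector column, Map<Long, Dictionary> dictionaries) {
    DictionaryEncoding encoding = column.getField().getDictionary();
    if (encoding == null) {
      return column; // not dictionary-encoded
    }
    Dictionary dictionary = dictionaries.get(encoding.getId());
    // DictionaryEncoder.decode swaps each index for its dictionary value,
    // allocating a new vector of the dictionary's value type.
    return DictionaryEncoder.decode(column, dictionary);
  }
}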

/*
 * Class: org_apache_arrow_dataset_jni_JniWrapper
 * Method: closeIterator
org/apache/arrow/dataset/jni/JniWrapper.java
@@ -52,7 +52,7 @@ private JniWrapper() {

  public native long scan(long scanTaskId);

  public native NativeRecordBatchHandle nextRecordBatch(long recordBatchIteratorId);
  public native NativeRecordBatchHandle[] nextRecordBatch(long recordBatchIteratorId);

  public native void closeIterator(long id);

org/apache/arrow/dataset/jni/NativeDictionaryBatchHandle.java
@@ -0,0 +1,41 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.arrow.dataset.jni;

/**
 * Hold pointers to an Arrow C++ DictionaryBatch.
 */
public class NativeDictionaryBatchHandle extends NativeRecordBatchHandle {

  /**
   * Dictionary ID.
   */
  private final long id;

  /**
   * Constructor.
   */
  public NativeDictionaryBatchHandle(long id, long numRows, Field[] fields, Buffer[] buffers) {
    super(numRows, fields, buffers);
    this.id = id;
  }

  public long getId() {
    return id;
  }
}
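A hypothetical consumer of the new contract, for illustration only: as the C++ loop above shows, the native side emits the data-batch handle first, followed by one of these dictionary handles per dictionary-encoded column, and getId() supplies the key later passed to loadDictionary.

final class HandleDispatchSketch {
  // Sketch of dispatching the mixed array returned by
  // JniWrapper.nextRecordBatch(); the real logic lives in the Reader below.
  static void dispatch(NativeRecordBatchHandle[] handles) {
    if (handles == null) {
      return; // stream ended
    }
    for (NativeRecordBatchHandle handle : handles) {
      if (handle instanceof NativeDictionaryBatchHandle) {
        NativeDictionaryBatchHandle dict = (NativeDictionaryBatchHandle) handle;
        System.out.println("dictionary batch, id=" + dict.getId()
            + ", rows=" + dict.getNumRows());
      } else {
        System.out.println("data batch, rows=" + handle.getNumRows());
      }
    }
  }
}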
(file name not shown in source; the native ScanTask implementation)
@@ -27,8 +27,8 @@
import org.apache.arrow.memory.BufferLedger;
import org.apache.arrow.memory.NativeUnderlingMemory;
import org.apache.arrow.memory.Ownerships;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -58,7 +58,7 @@ public Itr scan() {
    return new Itr() {

      private final Reader in = new Reader(JniWrapper.get().scan(scanTaskId));
      private VectorSchemaRoot peek = null;
      private ArrowBundledVectors peek = null;

      @Override
      public void close() throws Exception {
@@ -74,15 +74,15 @@ public boolean hasNext() {
          if (!in.loadNextBatch()) {
            return false;
          }
          peek = in.getVectorSchemaRoot();
          peek = new ArrowBundledVectors(in.getVectorSchemaRoot(), in.getDictionaryVectors());
          return true;
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }

      @Override
      public VectorSchemaRoot next() {
      public ArrowBundledVectors next() {
        if (!hasNext()) {
          throw new NoSuchElementException();
        }
@@ -113,10 +113,26 @@ private class Reader extends ArrowReader {
    public boolean loadNextBatch() throws IOException {
      // fixme it seems that the initialization is not thread-safe. Does caller already make it safe?
      ensureInitialized();
      NativeRecordBatchHandle handle = JniWrapper.get().nextRecordBatch(recordBatchIteratorId);
      if (handle == null) {
      NativeRecordBatchHandle[] handles = JniWrapper.get().nextRecordBatch(recordBatchIteratorId);
      if (handles == null) {
        return false;
      }
      for (NativeRecordBatchHandle handle : handles) {
        if (handle instanceof NativeDictionaryBatchHandle) {
          NativeDictionaryBatchHandle dbh = (NativeDictionaryBatchHandle) handle;
          ArrowRecordBatch dictionary = toArrowRecordBatch(dbh);
          ArrowDictionaryBatch db = new ArrowDictionaryBatch(dbh.getId(), dictionary, false);
          loadDictionary(db);
          continue;
        }
        // todo add and use NativeDataRecordBatch
        ArrowRecordBatch batch = toArrowRecordBatch(handle);
        loadRecordBatch(batch);
      }
      return true;
    }

    private ArrowRecordBatch toArrowRecordBatch(NativeRecordBatchHandle handle) {
      final ArrayList<ArrowBuf> buffers = new ArrayList<>();
      for (NativeRecordBatchHandle.Buffer buffer : handle.getBuffers()) {
        final BaseAllocator allocator = context.getAllocator();
@@ -127,14 +143,12 @@ public boolean loadNextBatch() throws IOException {
        buffers.add(buf);
      }
      try {
        loadRecordBatch(
            new ArrowRecordBatch((int) handle.getNumRows(), handle.getFields().stream()
                .map(field -> new ArrowFieldNode((int) field.length, (int) field.nullCount))
                .collect(Collectors.toList()), buffers));
        return new ArrowRecordBatch((int) handle.getNumRows(), handle.getFields().stream()
            .map(field -> new ArrowFieldNode((int) field.length, (int) field.nullCount))
            .collect(Collectors.toList()), buffers);
      } finally {
        buffers.forEach(b -> b.getReferenceManager().release());
      }
      return true;
    }

    @Override
org/apache/arrow/dataset/scanner/ScanTask.java
@@ -18,8 +18,10 @@
package org.apache.arrow.dataset.scanner;

import java.util.Iterator;
import java.util.Map;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;

/**
 * Read record batches from a range of a single data fragment. A
@@ -36,7 +38,20 @@ public interface ScanTask {
  /**
   * The iterator implementation for {@link VectorSchemaRoot}s.
   */
  interface Itr extends Iterator<VectorSchemaRoot>, AutoCloseable {
  interface Itr extends Iterator<ArrowBundledVectors>, AutoCloseable {
    // FIXME VectorSchemaRoot is not actually something ITERABLE. Using a reader convention instead.
  }

  /**
   * Emitted vectors including both values and dictionaries.
   */
  class ArrowBundledVectors {
    public final VectorSchemaRoot valueVectors;
    public final Map<Long, Dictionary> dictionaryVectors;

    public ArrowBundledVectors(VectorSchemaRoot valueVectors, Map<Long, Dictionary> dictionaryVectors) {
      this.valueVectors = valueVectors;
      this.dictionaryVectors = dictionaryVectors;
    }
  }
}
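At the public surface, an Itr now yields ArrowBundledVectors rather than bare VectorSchemaRoots, so dictionaries travel alongside the values they decode. A minimal usage sketch, assuming ScanTask declares Itr scan() as the implementation above suggests (ScanTask is assumed to be on the classpath from the same package):

import java.util.Map;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;

final class ScanTaskUsageSketch {
  static void consume(ScanTask task) throws Exception {
    try (ScanTask.Itr itr = task.scan()) {            // Itr extends AutoCloseable
      while (itr.hasNext()) {
        ScanTask.ArrowBundledVectors bundle = itr.next();
        VectorSchemaRoot root = bundle.valueVectors;  // data (index vectors for dict columns)
        Map<Long, Dictionary> dicts = bundle.dictionaryVectors;
        System.out.println("rows=" + root.getRowCount() + ", dictionaries=" + dicts.size());
      }
    }
  }
}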