Skip to content

Commit e3f0af4

Browse files
davisusanibarwestonpacebenibuslidavidmdanepitkin
authored andcommitted
GH-34252: [Java] Support ScannerBuilder::Project or ScannerBuilder::Filter as a Substrait proto extended expression (#35570)
### Rationale for this change To close apache/arrow#34252 ### What changes are included in this PR? This is a proposal to try to solve: 1. Receive a list of Substrait scalar expressions and use them to Project a Dataset - [x] Draft a Substrait Extended Expression to test (this will be generated by 3rd party project such as Isthmus) - [x] Use C++ draft PR to Serialize/Deserialize Extended Expression proto messages - [x] Create JNI Wrapper for ScannerBuilder::Project - [x] Create JNI API - [x] Testing coverage - [x] Documentation Current problem is: `java.lang.RuntimeException: Inferring column projection from FieldRef FieldRef.FieldPath(0)`. Not able to infer by column position by able to infer by colum name. This problem is solved by apache/arrow#35798 This PR needs/use this PRs/Issues: - apache/arrow#34834 - apache/arrow#34227 - apache/arrow#35579 2. Receive a Boolean-valued Substrait scalar expression and use it to filter a Dataset - [x] Working to identify activities ### Are these changes tested? Initial unit test added. ### Are there any user-facing changes? No * Closes: #34252 Lead-authored-by: david dali susanibar arce <davi.sarces@gmail.com> Co-authored-by: Weston Pace <weston.pace@gmail.com> Co-authored-by: benibus <bpharks@gmx.com> Co-authored-by: David Li <li.davidm96@gmail.com> Co-authored-by: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> Signed-off-by: David Li <li.davidm96@gmail.com>
1 parent 0d0c045 commit e3f0af4

File tree

5 files changed

+398
-16
lines changed

5 files changed

+398
-16
lines changed

dataset/src/main/cpp/jni_wrapper.cc

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "arrow/filesystem/path_util.h"
3030
#include "arrow/filesystem/s3fs.h"
3131
#include "arrow/engine/substrait/util.h"
32+
#include "arrow/engine/substrait/serde.h"
33+
#include "arrow/engine/substrait/relation.h"
3234
#include "arrow/ipc/api.h"
3335
#include "arrow/util/iterator.h"
3436
#include "jni_util.h"
@@ -200,7 +202,6 @@ arrow::Result<std::shared_ptr<arrow::Schema>> SchemaFromColumnNames(
200202
return arrow::Status::Invalid("Partition column '", ref.ToString(), "' is not in dataset schema");
201203
}
202204
}
203-
204205
return schema(std::move(columns))->WithMetadata(input->metadata());
205206
}
206207
} // namespace
@@ -317,6 +318,14 @@ std::shared_ptr<arrow::Table> GetTableByName(const std::vector<std::string>& nam
317318
return it->second;
318319
}
319320

321+
std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobject byte_buffer) {
322+
const auto *buff = reinterpret_cast<jbyte*>(env->GetDirectBufferAddress(byte_buffer));
323+
int length = env->GetDirectBufferCapacity(byte_buffer);
324+
std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(length));
325+
std::memcpy(buffer->mutable_data(), buff, length);
326+
return buffer;
327+
}
328+
320329
/*
321330
* Class: org_apache_arrow_dataset_jni_NativeMemoryPool
322331
* Method: getDefaultMemoryPool
@@ -455,11 +464,12 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
455464
/*
456465
* Class: org_apache_arrow_dataset_jni_JniWrapper
457466
* Method: createScanner
458-
* Signature: (J[Ljava/lang/String;JJ)J
467+
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
459468
*/
460469
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
461-
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, jlong batch_size,
462-
jlong memory_pool_id) {
470+
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
471+
jobject substrait_projection, jobject substrait_filter,
472+
jlong batch_size, jlong memory_pool_id) {
463473
JNI_METHOD_START
464474
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
465475
if (pool == nullptr) {
@@ -474,6 +484,40 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
474484
std::vector<std::string> column_vector = ToStringVector(env, columns);
475485
JniAssertOkOrThrow(scanner_builder->Project(column_vector));
476486
}
487+
if (substrait_projection != nullptr) {
488+
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env,
489+
substrait_projection);
490+
std::vector<arrow::compute::Expression> project_exprs;
491+
std::vector<std::string> project_names;
492+
arrow::engine::BoundExpressions bounded_expression =
493+
JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer));
494+
for(arrow::engine::NamedExpression& named_expression :
495+
bounded_expression.named_expressions) {
496+
project_exprs.push_back(std::move(named_expression.expression));
497+
project_names.push_back(std::move(named_expression.name));
498+
}
499+
JniAssertOkOrThrow(scanner_builder->Project(std::move(project_exprs), std::move(project_names)));
500+
}
501+
if (substrait_filter != nullptr) {
502+
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env,
503+
substrait_filter);
504+
std::optional<arrow::compute::Expression> filter_expr = std::nullopt;
505+
arrow::engine::BoundExpressions bounded_expression =
506+
JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer));
507+
for(arrow::engine::NamedExpression& named_expression :
508+
bounded_expression.named_expressions) {
509+
filter_expr = named_expression.expression;
510+
if (named_expression.expression.type()->id() == arrow::Type::BOOL) {
511+
filter_expr = named_expression.expression;
512+
} else {
513+
JniThrow("There is no filter expression in the expression provided");
514+
}
515+
}
516+
if (filter_expr == std::nullopt) {
517+
JniThrow("The filter expression has not been provided");
518+
}
519+
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
520+
}
477521
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));
478522

479523
auto scanner = JniGetOrThrow(scanner_builder->Finish());
@@ -748,10 +792,7 @@ JNIEXPORT void JNICALL
748792
arrow::engine::ConversionOptions conversion_options;
749793
conversion_options.named_table_provider = std::move(table_provider);
750794
// mapping arrow::Buffer
751-
auto *buff = reinterpret_cast<jbyte*>(env->GetDirectBufferAddress(plan));
752-
int length = env->GetDirectBufferCapacity(plan);
753-
std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(length));
754-
std::memcpy(buffer->mutable_data(), buff, length);
795+
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, plan);
755796
// execute plan
756797
std::shared_ptr<arrow::RecordBatchReader> reader_out =
757798
JniGetOrThrow(arrow::engine::ExecuteSerializedPlan(*buffer, nullptr, nullptr, conversion_options));

dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.arrow.dataset.jni;
1919

20+
import java.nio.ByteBuffer;
21+
2022
/**
2123
* JNI wrapper for Dataset API's native implementation.
2224
*/
@@ -66,15 +68,19 @@ private JniWrapper() {
6668

6769
/**
6870
* Create Scanner from a Dataset and get the native pointer of the Dataset.
71+
*
6972
* @param datasetId the native pointer of the arrow::dataset::Dataset instance.
7073
* @param columns desired column names.
7174
* Columns not in this list will not be emitted when performing scan operation. Null equals
7275
* to "all columns".
76+
* @param substraitProjection substrait extended expression to evaluate for project new columns
77+
* @param substraitFilter substrait extended expression to evaluate for apply filter
7378
* @param batchSize batch size of scanned record batches.
7479
* @param memoryPool identifier of memory pool used in the native scanner.
7580
* @return the native pointer of the arrow::dataset::Scanner instance.
7681
*/
77-
public native long createScanner(long datasetId, String[] columns, long batchSize, long memoryPool);
82+
public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
83+
ByteBuffer substraitFilter, long batchSize, long memoryPool);
7884

7985
/**
8086
* Get a serialized schema from native instance of a Scanner.

dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ public synchronized NativeScanner newScan(ScanOptions options) {
4040
if (closed) {
4141
throw new NativeInstanceReleasedException();
4242
}
43+
4344
long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
45+
options.getSubstraitProjection().orElse(null),
46+
options.getSubstraitFilter().orElse(null),
4447
options.getBatchSize(), context.getMemoryPool().getNativeInstanceId());
48+
4549
return new NativeScanner(context, scannerId);
4650
}
4751

dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.arrow.dataset.scanner;
1919

20+
import java.nio.ByteBuffer;
2021
import java.util.Optional;
2122

2223
import org.apache.arrow.util.Preconditions;
@@ -25,8 +26,10 @@
2526
* Options used during scanning.
2627
*/
2728
public class ScanOptions {
28-
private final Optional<String[]> columns;
2929
private final long batchSize;
30+
private final Optional<String[]> columns;
31+
private final Optional<ByteBuffer> substraitProjection;
32+
private final Optional<ByteBuffer> substraitFilter;
3033

3134
/**
3235
* Constructor.
@@ -56,6 +59,8 @@ public ScanOptions(long batchSize, Optional<String[]> columns) {
5659
Preconditions.checkNotNull(columns);
5760
this.batchSize = batchSize;
5861
this.columns = columns;
62+
this.substraitProjection = Optional.empty();
63+
this.substraitFilter = Optional.empty();
5964
}
6065

6166
public ScanOptions(long batchSize) {
@@ -69,4 +74,77 @@ public Optional<String[]> getColumns() {
6974
public long getBatchSize() {
7075
return batchSize;
7176
}
77+
78+
public Optional<ByteBuffer> getSubstraitProjection() {
79+
return substraitProjection;
80+
}
81+
82+
public Optional<ByteBuffer> getSubstraitFilter() {
83+
return substraitFilter;
84+
}
85+
86+
/**
87+
* Builder for Options used during scanning.
88+
*/
89+
public static class Builder {
90+
private final long batchSize;
91+
private Optional<String[]> columns;
92+
private ByteBuffer substraitProjection;
93+
private ByteBuffer substraitFilter;
94+
95+
/**
96+
* Constructor.
97+
* @param batchSize Maximum row number of each returned {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}
98+
*/
99+
public Builder(long batchSize) {
100+
this.batchSize = batchSize;
101+
}
102+
103+
/**
104+
* Set the Projected columns. Empty for scanning all columns.
105+
*
106+
* @param columns Projected columns. Empty for scanning all columns.
107+
* @return the ScanOptions configured.
108+
*/
109+
public Builder columns(Optional<String[]> columns) {
110+
Preconditions.checkNotNull(columns);
111+
this.columns = columns;
112+
return this;
113+
}
114+
115+
/**
116+
* Set the Substrait extended expression for Projection new columns.
117+
*
118+
* @param substraitProjection Expressions to evaluate for project new columns.
119+
* @return the ScanOptions configured.
120+
*/
121+
public Builder substraitProjection(ByteBuffer substraitProjection) {
122+
Preconditions.checkNotNull(substraitProjection);
123+
this.substraitProjection = substraitProjection;
124+
return this;
125+
}
126+
127+
/**
128+
* Set the Substrait extended expression for Filter.
129+
*
130+
* @param substraitFilter Expressions to evaluate for apply Filter.
131+
* @return the ScanOptions configured.
132+
*/
133+
public Builder substraitFilter(ByteBuffer substraitFilter) {
134+
Preconditions.checkNotNull(substraitFilter);
135+
this.substraitFilter = substraitFilter;
136+
return this;
137+
}
138+
139+
public ScanOptions build() {
140+
return new ScanOptions(this);
141+
}
142+
}
143+
144+
private ScanOptions(Builder builder) {
145+
batchSize = builder.batchSize;
146+
columns = builder.columns;
147+
substraitProjection = Optional.ofNullable(builder.substraitProjection);
148+
substraitFilter = Optional.ofNullable(builder.substraitFilter);
149+
}
72150
}

0 commit comments

Comments
 (0)