Commit 0f3b68b

Spark 3.5, 4.0: Support Spark Partial Limit Push Down
1 parent 21e6e41 commit 0f3b68b

19 files changed: +1083 -33 lines changed

api/src/main/java/org/apache/iceberg/FileScanTask.java

Lines changed: 4 additions & 0 deletions
@@ -54,4 +54,8 @@ default boolean isFileScanTask() {
   default FileScanTask asFileScanTask() {
     return this;
   }
+
+  default long minRecordCountEstimate() {
+    return deletes().isEmpty() ? file().recordCount() : 0;
+  }
 }
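
The new default gives planners a conservative lower bound on the rows a task is guaranteed to produce: the data file's record count when the task carries no delete files, and 0 otherwise, since deletes could remove any number of rows. A minimal sketch of how a caller can accumulate these bounds against a query limit; the helper class below is illustrative and not part of this commit.

import java.util.ArrayList;
import java.util.List;
import org.apache.iceberg.FileScanTask;

class LimitPlanningSketch {
  // Illustrative only: keep adding tasks until their guaranteed rows cover the limit.
  static List<FileScanTask> takeUntilLimit(Iterable<FileScanTask> tasks, long limit) {
    List<FileScanTask> selected = new ArrayList<>();
    long remaining = limit;
    for (FileScanTask task : tasks) {
      selected.add(task);
      // Tasks with delete files report 0, so they never shrink the remaining limit
      // and planning safely continues past them.
      remaining -= task.minRecordCountEstimate();
      if (remaining <= 0) {
        break;
      }
    }
    return selected;
  }
}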

org/apache/iceberg/events/LimitAwareScanTaskEvent.java (new file)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.iceberg.events;

/** Event sent to listeners when a scan task is planned, used for limit push down tracking. */
public final class LimitAwareScanTaskEvent {
  private final String taskName;
  private final long minRecordCountEstimate;

  public LimitAwareScanTaskEvent(String taskName, long minRecordCountEstimate) {
    this.taskName = taskName;
    this.minRecordCountEstimate = minRecordCountEstimate;
  }

  public String taskName() {
    return taskName;
  }

  public long minRecordCountEstimate() {
    return minRecordCountEstimate;
  }
}
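
TableScanUtil (next diff) publishes these events through the existing org.apache.iceberg.events.Listeners hook, so they can be observed by registering a listener. A minimal sketch, assuming the current Listener/Listeners API; the logging listener itself is illustrative.

import org.apache.iceberg.events.Listener;
import org.apache.iceberg.events.Listeners;
import org.apache.iceberg.events.LimitAwareScanTaskEvent;

class LimitPushDownEventLogger implements Listener<LimitAwareScanTaskEvent> {

  static void install() {
    // Deliver every LimitAwareScanTaskEvent emitted during scan planning to this logger.
    Listeners.register(new LimitPushDownEventLogger(), LimitAwareScanTaskEvent.class);
  }

  @Override
  public void notify(LimitAwareScanTaskEvent event) {
    System.out.printf(
        "planned %s, min record estimate %d%n", event.taskName(), event.minRecordCountEstimate());
  }
}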

core/src/main/java/org/apache/iceberg/util/TableScanUtil.java

Lines changed: 71 additions & 15 deletions
@@ -35,6 +35,8 @@
 import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.SplittableScanTask;
 import org.apache.iceberg.StructLike;
+import org.apache.iceberg.events.LimitAwareScanTaskEvent;
+import org.apache.iceberg.events.Listeners;
 import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.FluentIterable;
@@ -107,25 +109,35 @@ public static <T extends ScanTask> List<ScanTaskGroup<T>> planTaskGroups(
         planTaskGroups(CloseableIterable.withNoopClose(tasks), splitSize, lookback, openFileCost));
   }
 
-  @SuppressWarnings("unchecked")
   public static <T extends ScanTask> CloseableIterable<ScanTaskGroup<T>> planTaskGroups(
       CloseableIterable<T> tasks, long splitSize, int lookback, long openFileCost) {
+    return planTaskGroups(tasks, splitSize, lookback, openFileCost, 0);
+  }
 
-    validatePlanningArguments(splitSize, lookback, openFileCost);
+  @SuppressWarnings("unchecked")
+  public static <T extends ScanTask> CloseableIterable<ScanTaskGroup<T>> planTaskGroups(
+      CloseableIterable<T> tasks, long splitSize, int lookback, long openFileCost, int limit) {
 
-    // capture manifests which can be closed after scan planning
-    CloseableIterable<T> splitTasks =
-        CloseableIterable.combine(
-            FluentIterable.from(tasks)
-                .transformAndConcat(
-                    task -> {
-                      if (task instanceof SplittableScanTask<?>) {
-                        return ((SplittableScanTask<? extends T>) task).split(splitSize);
-                      } else {
-                        return ImmutableList.of(task);
-                      }
-                    }),
-            tasks);
+    validatePlanningArguments(splitSize, lookback, openFileCost);
+    CloseableIterable<T> splitTasks;
+    if (limit > 0) {
+      // optimize scan planning by stopping early when estimated row count reaches limit
+      splitTasks = splitScanTasksWithLimitPushDown(tasks, splitSize, limit);
+    } else {
+      // capture manifests which can be closed after scan planning
+      splitTasks =
+          CloseableIterable.combine(
+              FluentIterable.from(tasks)
+                  .transformAndConcat(
+                      task -> {
+                        if (task instanceof SplittableScanTask<?>) {
+                          return ((SplittableScanTask<? extends T>) task).split(splitSize);
+                        } else {
+                          return ImmutableList.of(task);
+                        }
+                      }),
+              tasks);
+    }
 
     Function<T, Long> weightFunc =
         task -> Math.max(task.sizeBytes(), task.filesCount() * openFileCost);
@@ -249,4 +261,48 @@ private static void validatePlanningArguments(long splitSize, int lookback, long
     Preconditions.checkArgument(lookback > 0, "Split planning lookback must be > 0: %s", lookback);
     Preconditions.checkArgument(openFileCost >= 0, "File open cost must be >= 0: %s", openFileCost);
   }
+
+  private static <T> CloseableIterable<T> splitScanTasksWithLimitPushDown(
+      CloseableIterable<T> tasks, long splitSize, int limit) {
+
+    List<T> candidateTasks = Lists.newArrayList();
+    long remainingLimit = limit;
+
+    for (T task : tasks) {
+      if (task instanceof SplittableScanTask<?>) {
+        @SuppressWarnings("unchecked")
+        SplittableScanTask<? extends T> splittable = (SplittableScanTask<? extends T>) task;
+        for (T splitTask : splittable.split(splitSize)) {
+          candidateTasks.add(splitTask);
+          remainingLimit = updateRemainingLimit(splitTask, remainingLimit);
+          if (remainingLimit <= 0) {
+            break;
+          }
+        }
+      } else {
+        candidateTasks.add(task);
+        remainingLimit = updateRemainingLimit(task, remainingLimit);
+      }
+
+      if (remainingLimit <= 0) {
+        break;
+      }
+    }
+
+    return CloseableIterable.combine(candidateTasks, tasks);
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> long updateRemainingLimit(T task, long remainingLimit) {
+    if (task instanceof ScanTask) {
+      ScanTask scanTask = (ScanTask) task;
+      if (scanTask.isFileScanTask()) {
+        FileScanTask fileTask = scanTask.asFileScanTask();
+        long estimate = fileTask.minRecordCountEstimate();
+        Listeners.notifyAll(new LimitAwareScanTaskEvent(fileTask.toString(), estimate));
+        return remainingLimit - estimate;
+      }
+    }
+    return remainingLimit;
+  }
 }
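
The new five-argument overload lets callers that already know the query's row limit pass it into planning; accumulated minRecordCountEstimate() values then let task splitting stop early, while a limit of 0 keeps the original behavior. A rough usage sketch; the split size, lookback, and open-file-cost values are placeholders rather than anything this commit prescribes.

import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.util.TableScanUtil;

class LimitAwarePlanningSketch {
  static CloseableIterable<ScanTaskGroup<FileScanTask>> planWithLimit(Table table, int limit) {
    CloseableIterable<FileScanTask> fileTasks = table.newScan().planFiles();
    return TableScanUtil.planTaskGroups(
        fileTasks,
        134_217_728L, // 128 MB target split size (placeholder)
        10, // bin-packing lookback (placeholder)
        4_194_304L, // per-file open cost in bytes (placeholder)
        limit); // pushed-down LIMIT; 0 falls back to the previous planning path
  }
}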

org/apache/iceberg/spark/LimitPushDownBenchmark.java (new file)

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.iceberg.spark;

import com.google.errorprone.annotations.FormatMethod;
import com.google.errorprone.annotations.FormatString;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Timeout;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

/**
 * A benchmark that evaluates the limit push down performance.
 *
 * <p>To run this benchmark for spark-3.5: <code>
 *   ./gradlew -DsparkVersions=3.5 :iceberg-spark:iceberg-spark-extensions-3.5_2.12:jmh
 *       -PjmhIncludeRegex=LimitPushDownBenchmark
 *       -PjmhOutputPath=benchmark/iceberg-limit-push-down-benchmark.txt
 * </code>
 */
@Fork(1)
@State(Scope.Benchmark)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
@Timeout(time = 10, timeUnit = TimeUnit.MINUTES)
@BenchmarkMode(Mode.AverageTime)
public class LimitPushDownBenchmark {

  private static final String TABLE_NAME = "test_limit_table";

  @Param({"100", "1000", "10000"})
  private int limitValue;

  @Param({"true", "false"})
  private boolean limitPushDownEnabled;

  private SparkSession spark;

  @Setup
  public void setupBenchmark() throws NoSuchTableException, ParseException {
    setupSpark();
    setupTable();
  }

  @TearDown
  public void tearDownBenchmark() {
    dropTable();
    tearDownSpark();
  }

  @Benchmark
  @Threads(1)
  public void limitQuery(Blackhole blackhole) {
    spark
        .conf()
        .set("spark.sql.iceberg.limit-push-down.enabled", String.valueOf(limitPushDownEnabled));

    Dataset<Row> result =
        spark.sql(
            String.format(Locale.ROOT, "SELECT * FROM local.%s LIMIT %d", TABLE_NAME, limitValue));

    blackhole.consume(result.count());
  }

  @Benchmark
  @Threads(1)
  public void limitQueryWithPartitionPruning(Blackhole blackhole) {
    spark
        .conf()
        .set("spark.sql.iceberg.limit-push-down.enabled", String.valueOf(limitPushDownEnabled));

    Dataset<Row> result =
        spark.sql(
            String.format(
                Locale.ROOT,
                "SELECT * FROM local.%s WHERE category != '0' LIMIT %d",
                TABLE_NAME,
                limitValue));

    blackhole.consume(result.count());
  }

  private void setupSpark() {
    this.spark =
        SparkSession.builder()
            .appName("limit-push-down-benchmark")
            .master("local[1]")
            .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
            .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
            .config("spark.sql.catalog.local.type", "hadoop")
            .config("spark.sql.catalog.local.warehouse", "/tmp/iceberg-benchmark-warehouse")
            .getOrCreate();
  }

  private void setupTable() {
    sql("DROP TABLE IF EXISTS local.%s PURGE", TABLE_NAME);

    sql(
        "CREATE TABLE local.%s (id BIGINT, data STRING, value DOUBLE, category STRING) "
            + "USING iceberg PARTITIONED BY (id)",
        TABLE_NAME);

    // Insert substantial data across multiple partitions
    for (int partition = 0; partition < 100; partition++) {
      StringBuilder values = new StringBuilder();
      for (int row = 0; row < 5000; row++) {
        int id = partition * 5000 + row;
        if (values.length() > 0) {
          values.append(", ");
        }
        values.append(
            String.format(
                Locale.ROOT,
                "(%d, 'data_%d', %f, 'category_%d')",
                id,
                id,
                Math.random() * 1000,
                partition));
      }
      sql("INSERT INTO local.%s VALUES %s", TABLE_NAME, values.toString());
    }
  }

  private void dropTable() {
    sql("DROP TABLE IF EXISTS local.%s PURGE", TABLE_NAME);
  }

  private void tearDownSpark() {
    if (spark != null) {
      spark.stop();
    }
  }

  @FormatMethod
  private void sql(@FormatString String query, Object... args) {
    spark.sql(String.format(query, args));
  }
}

spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java

Lines changed: 8 additions & 0 deletions
@@ -277,6 +277,14 @@ public boolean aggregatePushDownEnabled() {
         .parse();
   }
 
+  public boolean limitPushDownEnabled() {
+    return confParser
+        .booleanConf()
+        .sessionConf(SparkSQLProperties.LIMIT_PUSH_DOWN_ENABLED)
+        .defaultValue(SparkSQLProperties.LIMIT_PUSH_DOWN_ENABLED_DEFAULT)
+        .parse();
+  }
+
   public boolean adaptiveSplitSizeEnabled() {
     return confParser
         .booleanConf()
spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java

Lines changed: 4 additions & 0 deletions
@@ -101,4 +101,8 @@ private SparkSQLProperties() {}
   // Controls whether to report available column statistics to Spark for query optimization.
   public static final String REPORT_COLUMN_STATS = "spark.sql.iceberg.report-column-stats";
   public static final boolean REPORT_COLUMN_STATS_DEFAULT = true;
+
+  // Controls whether to push down limit to Iceberg scan planning
+  public static final String LIMIT_PUSH_DOWN_ENABLED = "spark.sql.iceberg.limit-push-down.enabled";
+  public static final boolean LIMIT_PUSH_DOWN_ENABLED_DEFAULT = true;
 }
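
Because the flag is read per session through SparkReadConf#limitPushDownEnabled() above, it can be toggled at runtime without restarting the application; the benchmark earlier in this commit flips it the same way. A small sketch, assuming spark is an existing SparkSession configured with an Iceberg catalog.

// Disable limit push down for this session (e.g. to compare query plans), then restore the default.
spark.conf().set("spark.sql.iceberg.limit-push-down.enabled", "false");
spark.conf().set("spark.sql.iceberg.limit-push-down.enabled", "true");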

spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java

Lines changed: 3 additions & 2 deletions
@@ -81,8 +81,9 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan<PartitionScanTask>
       SparkReadConf readConf,
       Schema expectedSchema,
       List<Expression> filters,
-      Supplier<ScanReport> scanReportSupplier) {
-    super(spark, table, scan, readConf, expectedSchema, filters, scanReportSupplier);
+      Supplier<ScanReport> scanReportSupplier,
+      int pushedLimit) {
+    super(spark, table, scan, readConf, expectedSchema, filters, scanReportSupplier, pushedLimit);
 
     this.snapshotId = readConf.snapshotId();
     this.startSnapshotId = readConf.startSnapshotId();
