
Commit 6708f2a

[SPARK-38959][SQL] DataSource V2: Support runtime group filtering in row-level commands
1 parent 03ef022 commit 6708f2a

8 files changed: +295 -16 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java

Lines changed: 14 additions & 0 deletions
@@ -21,6 +21,7 @@
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
@@ -68,6 +69,19 @@ default String description() {
   * be returned by the scan, even if a filter can narrow the set of changes to a single file
   * in the partition. Similarly, a data source that can swap individual files must produce all
   * rows from files where at least one record must be changed, not just rows that must be changed.
+  * <p>
+  * Data sources that replace groups of data (e.g. files, partitions) may prune entire groups
+  * using provided data source filters when building a scan for this row-level operation.
+  * However, such data skipping is limited as not all expressions can be converted into data source
+  * filters and some can only be evaluated by Spark (e.g. subqueries). Since rewriting groups is
+  * expensive, Spark allows group-based data sources to filter groups at runtime. The runtime
+  * filtering enables data sources to narrow down the scope of rewriting to only groups that must
+  * be rewritten. If the row-level operation scan implements {@link SupportsRuntimeV2Filtering},
+  * Spark will execute a query at runtime to find which records match the row-level condition.
+  * The runtime group filter subquery will leverage a regular batch scan, which isn't required to
+  * produce all rows in a group if any are returned. The information about matching records will
+  * be passed back into the row-level operation scan, allowing data sources to discard groups
+  * that don't have to be rewritten.
   */
  ScanBuilder newScanBuilder(CaseInsensitiveStringMap options);
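
The new paragraph above describes the contract from the connector side. As a rough, hypothetical sketch (GroupBasedScan, partitionColumns and the no-op pruning body are illustrative, not part of this commit), a group-based scan could advertise its partition columns as filter attributes and discard non-matching groups once Spark hands back the runtime predicates:

import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType

// Hypothetical scan used by a row-level operation of a group-based data source.
class GroupBasedScan(schema: StructType, partitionColumns: Seq[String])
  extends Scan with SupportsRuntimeV2Filtering {

  override def readSchema(): StructType = schema

  // Tell Spark which attributes identify rewritable groups (here, the partition columns).
  override def filterAttributes(): Array[NamedReference] =
    partitionColumns.map(Expressions.column(_)).toArray

  // Invoked at runtime with predicates derived from the group filter subquery;
  // a real source would discard groups whose partition values cannot match.
  override def filter(predicates: Array[Predicate]): Unit = {
    // e.g. prune the planned input groups using `predicates` (illustrative no-op here)
  }
}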

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 18 additions & 0 deletions
@@ -412,6 +412,21 @@ object SQLConf {
       .longConf
       .createWithDefault(67108864L)
 
+  val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED =
+    buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled")
+      .doc("Enables runtime group filtering for group-based row-level operations. " +
+        "Data sources that replace groups of data (e.g. files, partitions) may prune entire " +
+        "groups using provided data source filters when planning a row-level operation scan. " +
+        "However, such filtering is limited as not all expressions can be converted into data " +
+        "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " +
+        "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " +
+        "records match the condition of the row-level operation. The information about matching " +
+        "records will be passed back to the row-level operation scan, allowing data sources to " +
+        "discard groups that don't have to be rewritten.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled")
     .internal()
     .doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " +
@@ -4084,6 +4099,9 @@ class SQLConf extends Serializable with Logging {
   def runtimeFilterCreationSideThreshold: Long =
     getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD)
 
+  def runtimeRowLevelOperationGroupFilterEnabled: Boolean =
+    getConf(RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED)
+
   def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
 
   def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
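
The flag is on by default starting with 3.4.0. If the extra runtime query is not worth it for a given workload, it can be disabled per session, for example:

// Disable runtime group filtering for the current SparkSession (e.g. in spark-shell)
spark.conf.set("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled", "false")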

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala

Lines changed: 4 additions & 1 deletion
@@ -34,6 +34,8 @@ class InMemoryRowLevelOperationTable(
     properties: util.Map[String, String])
   extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations {
 
+  var replacedPartitions: Seq[Seq[Any]] = Seq.empty
+
   override def newRowLevelOperationBuilder(
       info: RowLevelOperationInfo): RowLevelOperationBuilder = {
     () => PartitionBasedOperation(info.command)
@@ -88,8 +90,9 @@ class InMemoryRowLevelOperationTable(
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
-      val readPartitions = readRows.map(r => getKey(r, schema))
+      val readPartitions = readRows.map(r => getKey(r, schema)).distinct
       dataMap --= readPartitions
+      replacedPartitions = readPartitions
       withData(newData, schema)
     }
   }
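
The switch to .distinct and the new replacedPartitions field exist because several read rows typically map to the same partition key, and test suites can then assert exactly which groups a command rewrote. A tiny standalone illustration (the keys are made up):

// Deduplicating the per-row partition keys yields one entry per replaced partition.
val readPartitionKeys: Seq[Seq[Any]] = Seq(Seq("hr"), Seq("hr"), Seq("sales"))
val replacedPartitions = readPartitionKeys.distinct // Seq(Seq("hr"), Seq("sales"))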

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 3 additions & 2 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
 import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
-import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
+import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
@@ -50,7 +50,8 @@ class SparkOptimizer(
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("PartitionPruning", Once,
-      PartitionPruning) :+
+      PartitionPruning,
+      RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
     Batch("InjectRuntimeFilter", FixedPoint(1),
       InjectRuntimeFilter) :+
     Batch("MergeScalarSubqueries", Once,

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelati
 case class PlanAdaptiveDynamicPruningFilters(
     rootPlan: AdaptiveSparkPlanExec) extends Rule[SparkPlan] with AdaptiveSparkPlanHelper {
   def apply(plan: SparkPlan): SparkPlan = {
-    if (!conf.dynamicPartitionPruningEnabled) {
+    if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) {
       return plan
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
   }
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!conf.dynamicPartitionPruningEnabled) {
+    if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) {
       return plan
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.dynamicpruning
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation}
+
+/**
+ * A rule that assigns a subquery to filter groups in row-level operations at runtime.
+ *
+ * Data skipping during job planning for row-level operations is limited to expressions that can be
+ * converted to data source filters. Since not all expressions can be pushed down that way and
+ * rewriting groups is expensive, Spark allows data sources to filter groups at runtime.
+ * If the primary scan in a group-based row-level operation supports runtime filtering, this rule
+ * will inject a subquery to find all rows that match the condition so that data sources know
+ * exactly which groups must be rewritten.
+ *
+ * Note this rule only applies to group-based row-level operations.
+ */
+case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan])
+  extends Rule[LogicalPlan] with PredicateHelper {
+
+  import DataSourceV2Implicits._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    // apply special dynamic filtering only for group-based row-level operations
+    case GroupBasedRowLevelOperation(replaceData, cond,
+        DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _))
+        if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral =>
+
+      // use reference equality on the scan to find required scan relations
+      val newQuery = replaceData.query transformUp {
+        case r: DataSourceV2ScanRelation if r.scan eq scan =>
+          // use the original table instance that was loaded for this row-level operation
+          // in order to leverage a regular batch scan in the group filter query
+          val originalTable = r.relation.table.asRowLevelOperationTable.table
+          val relation = r.relation.copy(table = originalTable)
+          val matchingRowsPlan = buildMatchingRowsPlan(relation, cond)
+
+          val filterAttrs = scan.filterAttributes
+          val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
+          val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
+          val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys)
+
+          Filter(dynamicPruningCond, r)
+      }
+
+      // optimize subqueries to rewrite them as joins and trigger job planning
+      replaceData.copy(query = optimizeSubqueries(newQuery))
+  }
+
+  private def buildMatchingRowsPlan(
+      relation: DataSourceV2Relation,
+      cond: Expression): LogicalPlan = {
+
+    val matchingRowsPlan = Filter(cond, relation)
+
+    // clone the relation and assign new expr IDs to avoid conflicts
+    matchingRowsPlan transformUpWithNewOutput {
+      case r: DataSourceV2Relation if r eq relation =>
+        val oldOutput = r.output
+        val newOutput = oldOutput.map(_.newInstance())
+        r.copy(output = newOutput) -> oldOutput.zip(newOutput)
+    }
+  }
+
+  private def buildDynamicPruningCond(
+      matchingRowsPlan: LogicalPlan,
+      buildKeys: Seq[Attribute],
+      pruningKeys: Seq[Attribute]): Expression = {
+
+    val buildQuery = Project(buildKeys, matchingRowsPlan)
+    val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) =>
+      DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false)
+    }
+    dynamicPruningSubqueries.reduce(And)
+  }
+}
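
As a usage illustration (the catalog and table names below are hypothetical), a conditional DELETE whose predicate contains a subquery cannot be pushed down as a data source filter; with this rule, the injected DynamicPruningSubquery on the scan's filter attributes lets a group-based source rewrite only the groups that actually contain matching rows:

// Hypothetical group-based target table and a small lookup of ids to purge.
spark.sql(
  """DELETE FROM testcat.db.events
    |WHERE event_id IN (SELECT event_id FROM testcat.db.obsolete_events)""".stripMargin)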
