
Commit aa62055

wangyum authored and GitHub Enterprise committed

[CARMEL-6243] Handle outer join build side skew (#1088)

* HandleOuterJoinBuildSideSkew
* fix
* handleOuterJoinBuildSideSkew
* Check optimize tag
* fix
* Update SQLConf.scala
* Update SQLConf.scala

1 parent ea35516 commit aa62055

File tree

7 files changed: +287 −1 lines

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 17 additions & 0 deletions

```diff
@@ -3079,6 +3079,23 @@ private[spark] object Utils extends Logging {
       0
     }
   }
+
+  /**
+   * Returns the median of a long array.
+   *
+   * @param sizes the values to take the median of
+   * @param alreadySorted whether `sizes` is already sorted in ascending order
+   * @return the median, floored at 1
+   */
+  def median(sizes: Array[Long], alreadySorted: Boolean): Long = {
+    val len = sizes.length
+    val sortedSize = if (alreadySorted) sizes else sizes.sorted
+    len match {
+      case _ if (len % 2 == 0) =>
+        math.max((sortedSize(len / 2) + sortedSize(len / 2 - 1)) / 2, 1)
+      case _ => math.max(sortedSize(len / 2), 1)
+    }
+  }
 }
 
 private[util] object CallerContext extends Logging {
```
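For a quick sense of the helper, here is a minimal sketch of the skew check that the new AQE rule (below) builds on this function; the partition sizes and the threshold are made-up example values:

```scala
import org.apache.spark.util.Utils

// One partition is far larger than the rest (sizes already sorted ascending).
val bytesByPartition = Array(8L, 9L, 10L, 11L, 2000L)
val threshold = 5.0

// median = 10 and max = 2000, so 2000 > 10 * 5 flags the stage as skewed.
val skewed =
  bytesByPartition.max > Utils.median(bytesByPartition, alreadySorted = true) * threshold
```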

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 10 additions & 0 deletions

```diff
@@ -183,6 +183,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
   }
 
+  /**
+   * Tests whether any [[TreeNode]] satisfies the condition specified in `f`.
+   * The condition is recursively applied to this node and all of its children (pre-order).
+   */
+  def exists(f: BaseType => Boolean): Boolean = if (f(this)) {
+    true
+  } else {
+    children.exists(_.exists(f))
+  }
+
   /**
    * Runs the given function on this node and then recursively on [[children]].
    * @param f the function to be applied to each node in the tree.
```
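The `containsBloomFilter` helper in the new rule below uses this method in exactly this shape; as a standalone sketch (the `hasFilter` name is just for illustration):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

// Pre-order test over a plan tree: does any node contain a Filter?
def hasFilter(plan: LogicalPlan): Boolean = plan.exists {
  case _: Filter => true
  case _ => false
}
```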

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions

```diff
@@ -355,6 +355,20 @@ object SQLConf {
     .checkValue(threshold => threshold >= 0, "The maximum row count must be non-negative.")
     .createWithDefault(0)
 
+  val HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_ENABLED =
+    buildConf("spark.sql.optimizer.handleOuterJoinBuildSideSkew.enabled")
+      .doc("When true, enable handling outer join build side skew.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_THRESHOLD =
+    buildConf("spark.sql.optimizer.handleOuterJoinBuildSideSkewThreshold")
+      .doc("The build side is treated as skewed when its largest shuffle partition exceeds the median partition size multiplied by this value.")
+      .version("3.3.0")
+      .doubleConf
+      .createWithDefault(200)
+
   val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
     .doc("When set to true Spark SQL will automatically select a compression codec for each " +
       "column based on statistics of the data.")
```

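A minimal sketch of enabling the feature for a session with the keys added above (`spark` is an active SparkSession; the threshold of 5 matches the test suite below and is otherwise arbitrary):

```scala
spark.conf.set("spark.sql.optimizer.handleOuterJoinBuildSideSkew.enabled", "true")
// Treat the build side as skewed once its largest shuffle partition
// exceeds 5x the median partition size.
spark.conf.set("spark.sql.optimizer.handleOuterJoinBuildSideSkewThreshold", "5")
```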
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 7 additions & 1 deletion

```diff
@@ -787,7 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }
 
   object BasicOperators extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    private def applyLocally(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
       case i: InsertIntoDataSource =>
         InsertIntoDataSourceExec(planLater(i.query), i.overwrite,
@@ -929,6 +929,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         planLater(r.child) :: Nil
       case _ => Nil
     }
+
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+      val sparkPlan = applyLocally(plan)
+      sparkPlan.foreach { p => plan.getOptimizeTags().foreach(p.addOptimizeTag) }
+      sparkPlan
+    }
   }
 
   object CompactDataSourceTable extends Strategy {
```

This refactoring routes planning through `applyLocally` so that any optimize tags recorded on the logical plan (such as the one attached by `HandleOuterJoinBuildSideSkew` below) are copied onto the resulting physical plans, where the test suite later asserts their presence.

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -32,6 +32,7 @@ class AQEOptimizer(sparkSession: SparkSession) extends RuleExecutor[LogicalPlan]
   private val defaultBatches = Seq(
     Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
     Batch("Adaptive Bloom Filter Join", Once, AdaptiveBloomFilterJoin(sparkSession)),
+    Batch("Handle Outer Join Build Side Skew", Once, HandleOuterJoinBuildSideSkew),
     Batch("Eliminate Join to Empty Relation", Once, EliminateJoinToEmptyRelation),
     Batch("Optimize bloom filter Join", Once, OptimizeBloomFilterJoin)
   )
```
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleOuterJoinBuildSideSkew.scala

Lines changed: 105 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Expression, Literal, PredicateHelper, ScalarSubquery, XxHash64}
import org.apache.spark.sql.catalyst.expressions.aggregate.BuildBloomFilter
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
 * Rewrites a left outer join whose build (right) side is skewed into a union of:
 *   1. an inner join, whose build side may additionally be pruned by a bloom filter
 *      built from the stream (left) side join keys, and
 *   2. a left anti join that recovers the unmatched left rows, padded with nulls.
 */
object HandleOuterJoinBuildSideSkew extends Rule[LogicalPlan]
  with JoinSelectionHelper with PredicateHelper with Logging {

  private def insertPredicate(
      pruningKeys: Seq[Expression],
      pruningPlan: LogicalPlan,
      filteringKey: Seq[Expression],
      filteringPlan: LogicalPlan): LogicalPlan = {
    val filteringRowCount = filteringPlan.stats.rowCount.get
    // To improve build bloom filter performance.
    val coalesceNum = math.max(math.ceil(filteringRowCount.toDouble / 4000000.0).toInt, 1)

    val bloomFilterAgg =
      new BuildBloomFilter(new XxHash64(filteringKey),
        math.max(filteringRowCount.toLong, 1L), true, 0, 0)
    val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
    val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias),
      Repartition(coalesceNum, false, filteringPlan)))

    val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
    Filter(BloomFilterMightContain(bloomFilterSubquery, new XxHash64(pruningKeys)), pruningPlan)
  }

  private def containsBloomFilter(plan: LogicalPlan): Boolean = {
    plan.exists {
      case Filter(condition, _) =>
        splitConjunctivePredicates(condition).exists {
          case _: BloomFilterMightContain => true
          case _ => false
        }
      case _ => false
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.getConf(SQLConf.HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_ENABLED)) return plan

    plan.transformDown {
      case join @ ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, _,
          left @ LogicalQueryStage(_, stage1: ShuffleQueryStageExec),
          right @ LogicalQueryStage(_, stage2: ShuffleQueryStageExec), _)
          if stage1.isMaterialized && stage2.isMaterialized &&
            !canPlanAsBroadcastHashJoin(join, conf) && !containsBloomFilter(right) =>
        val rightSize = stage2.mapStats.get.bytesByPartitionId
        val threshold = conf.getConf(SQLConf.HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_THRESHOLD)
        val maxBloomFilterEntries = conf.dynamicBloomFilterJoinPruningMaxBloomFilterEntries

        // The build side is skewed when its largest shuffle partition exceeds the
        // median partition size by the configured factor.
        if (rightSize.max > Utils.median(rightSize, false) * threshold) {
          // 1. Insert bloom filter
          val insertBF = if (left.stats.rowCount.exists(_ <= maxBloomFilterEntries)) {
            insertPredicate(rightKeys, right, leftKeys, left)
          } else {
            right
          }
          // TODO: 2. Insert partial aggregate
          val joinAttrs = join.condition.map(_.references.filter(canEvaluate(_, right)).toSeq)
            .getOrElse(Nil)
          val insertPartialAgg =
            if (joinAttrs.nonEmpty) PartialAggregate(joinAttrs, joinAttrs, insertBF) else insertBF

          // Should not convert to BHJ
          val joinHint = JoinHint(Some(HintInfo(strategy = Some(NO_BROADCAST_HASH))), None)
          val union = Union(
            join.copy(right = insertBF, joinType = Inner, hint = joinHint),
            Project(left.output ++
              right.output.map(name => Alias(Literal(null, name.dataType), name.name)()),
              Join(left, insertBF, LeftAnti, join.condition, join.hint)))
          union.addOptimizeTag(s"created by ${this.simpleRuleName}")
          union
        } else {
          join
        }
    }
  }
}
```
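In effect the rule relies on the standard identity that a left outer join equals the inner-join matches plus the unmatched left rows padded with nulls; the inner-join branch is then open to further optimization, and its skewed build side can be shrunk by the bloom filter. A conceptual DataFrame sketch of that equivalence, on hypothetical tables `t1` and `t2` joined on column `a` (`spark` is an active SparkSession); this is not the literal plan the rule builds:

```scala
import org.apache.spark.sql.functions.{col, lit}

val t1 = spark.table("t1")
val t2 = spark.table("t2")

// Original shape: t1 LEFT OUTER JOIN t2 ON t1.a = t2.a

// Branch 1: the matching rows, as a plain inner join (the rule additionally
// shrinks t2 with a bloom filter built over t1's join keys).
val matched = t1.join(t2, t1("a") === t2("a"), "inner")

// Branch 2: the left rows with no match, padded with nulls for t2's columns.
val unmatched = t1.join(t2, t1("a") === t2("a"), "left_anti")
  .select(t1.columns.map(col) ++
    t2.schema.map(f => lit(null).cast(f.dataType).as(f.name)): _*)

// Row-for-row equivalent to the original left outer join.
val rewritten = matched.union(unmatched)
```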
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/HandleOuterJoinBuildSideSkewSuite.scala

Lines changed: 133 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.scalatest.PrivateMethodTester

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class HandleOuterJoinBuildSideSkewSuite
  extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper
  with PrivateMethodTester {

  protected def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
    var finalPlanCnt = 0
    val listener = new SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = {
        event match {
          case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) =>
            if (sparkPlanInfo.simpleString.startsWith(
              "AdaptiveSparkPlan isFinalPlan=true")) {
              finalPlanCnt += 1
            }
          case _ => // ignore other events
        }
      }
    }
    spark.sparkContext.addSparkListener(listener)

    val dfAdaptive = spark.sql(query)
    val planBefore = dfAdaptive.queryExecution.executedPlan
    assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false"))
    val result = dfAdaptive.collect()
    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
      val df = spark.sql(query)
      checkAnswer(df, result)
    }
    val planAfter = dfAdaptive.queryExecution.executedPlan
    assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan

    spark.sparkContext.listenerBus.waitUntilEmpty()
    // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that
    // exist out of query stages.
    val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1)
    assert(finalPlanCnt == expectedFinalPlanCnt)
    spark.sparkContext.removeSparkListener(listener)

    val exchanges = adaptivePlan.collect {
      case e: Exchange => e
    }
    assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
    (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
  }

  private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
    collect(plan) {
      case j: BroadcastHashJoinExec => j
    }
  }

  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
    collect(plan) {
      case j: SortMergeJoinExec => j
    }
  }

  private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = {
    collect(plan) {
      case j: ShuffledHashJoinExec => j
    }
  }

  private def hasBloomFilterJoin(plan: SparkPlan): Seq[FilterExec] = {
    collectWithSubqueries(plan) {
      case f @ FilterExec(e, _) if e.isInstanceOf[BloomFilterMightContain] => f
    }
  }

  test("Handle outer join build side skew suite") {
    withSQLConf(
      SQLConf.HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.HANDLE_OUTER_JOIN_BUILD_SIDE_SKEW_THRESHOLD.key -> "5") {
      withTable("t1", "t2") {
        spark.range(10).selectExpr("id as a", "id as b", "id as c").write.saveAsTable("t1")
        spark.range(10000).selectExpr("1 as a", "id as b", "id as c").write.saveAsTable("t2")

        sql("insert into t1 values(null, null, null)")
        sql("insert into t2 values(null, null, null)")

        val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
          "select * from t1 left join t2 on t1.a = t2.a")

        assert(findTopLevelSortMergeJoin(plan).size === 1)
        assert(hasBloomFilterJoin(plan).size === 0)
        assert(collect(plan) { case j: UnionExec => j }.size === 0)
        assert(findTopLevelSortMergeJoin(adaptivePlan).size === 2)
        assert(hasBloomFilterJoin(adaptivePlan).size === 2)
        assert(collect(adaptivePlan) { case j: UnionExec => j }.size === 1)

        // Check optimize tag
        assert(adaptivePlan.toString.contains("created by HandleOuterJoinBuildSideSkew"))
      }
    }
  }
}
```
