Skip to content

Commit 514e30c

Browse files
committed
[SPARK-32220][SQL] SHUFFLE_REPLICATE_NL Hint should not change Non-Cartesian Product join result
1 parent 3b0aee3 commit 514e30c

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
244244
// 4. Pick cartesian product if join type is inner like.
245245
// 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have
246246
// other choice.
247-
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
247+
case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
248248
def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
249249
val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
250250
val wantToBuildRight = canBuildRight(joinType) && buildRight
@@ -286,7 +286,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
286286

287287
def createCartesianProduct() = {
288288
if (joinType.isInstanceOf[InnerLike]) {
289-
Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
289+
Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), p.condition)))
290290
} else {
291291
None
292292
}

sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,4 +570,31 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
570570
assert(joinHints == expectedHints)
571571
}
572572
}
573+
574+
test("SPARK-32220: Non Cartesian Product Join Result Correct with SHUFFLE_REPLICATE_NL hint") {
575+
withTempView("t1", "t2") {
576+
Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
577+
Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
578+
val df1 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key = t2.key")
579+
val df2 = sql("SELECT * from t1 join t2 ON t1.key = t2.key")
580+
assert(df1.collect().size == df2.collect().size)
581+
582+
val df3 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2")
583+
val df4 = sql("SELECT * from t1 join t2")
584+
assert(df3.collect().size == df4.collect().size)
585+
586+
val df5 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < t2.key")
587+
val df6 = sql("SELECT * from t1 join t2 ON t1.key < t2.key")
588+
assert(df5.collect().size == df6.collect().size)
589+
590+
val df7 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < 2")
591+
val df8 = sql("SELECT * from t1 join t2 ON t1.key < 2")
592+
assert(df7.collect().size == df8.collect().size)
593+
594+
595+
val df9 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t2.key < 2")
596+
val df10 = sql("SELECT * from t1 join t2 ON t2.key < 2")
597+
assert(df9.collect().size == df10.collect().size)
598+
}
599+
}
573600
}

0 commit comments

Comments
 (0)