diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a983bc850057..e041c54c61db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -244,7 +244,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // 4. Pick cartesian product if join type is inner like.
       // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have
       //    other choice.
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
+      case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
         def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
           val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
           val wantToBuildRight = canBuildRight(joinType) && buildRight
@@ -286,7 +286,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
         def createCartesianProduct() = {
           if (joinType.isInstanceOf[InnerLike]) {
-            Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
+            Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), p.condition)))
           } else {
             None
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index f68c41694126..71f7a708ad68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -570,4 +570,31 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
       assert(joinHints == expectedHints)
     }
   }
+
+  test("SPARK-32220: Non Cartesian Product Join Result Correct with SHUFFLE_REPLICATE_NL hint") {
+    withTempView("t1", "t2") {
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
+      val df1 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key = t2.key")
+      val df2 = sql("SELECT * from t1 join t2 ON t1.key = t2.key")
+      assert(df1.collect().size == df2.collect().size)
+
+      val df3 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2")
+      val df4 = sql("SELECT * from t1 join t2")
+      assert(df3.collect().size == df4.collect().size)
+
+      val df5 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < t2.key")
+      val df6 = sql("SELECT * from t1 join t2 ON t1.key < t2.key")
+      assert(df5.collect().size == df6.collect().size)
+
+      val df7 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < 2")
+      val df8 = sql("SELECT * from t1 join t2 ON t1.key < 2")
+      assert(df7.collect().size == df8.collect().size)
+
+
+      val df9 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t2.key < 2")
+      val df10 = sql("SELECT * from t1 join t2 ON t2.key < 2")
+      assert(df9.collect().size == df10.collect().size)
+    }
+  }
 }