diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index 9f3c7ef9c28a..a5fcbe6f16b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -61,9 +61,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] } // LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join. - case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _) + case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _) if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty && !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) && + canPushThroughCondition(agg.children, joinCond, rightOp) && canPlanAsBroadcastHashJoin(join, conf) => val aliasMap = getAliasMap(agg) val canPushDownPredicate = (predicate: Expression) => { @@ -110,11 +111,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] } /** - * Check if we can safely push a join through a project or union by making sure that attributes - * referred in join condition do not contain the same attributes as the plan they are moved - * into. This can happen when both sides of join refers to the same source (self join). This - * function makes sure that the join condition refers to attributes that are not ambiguous (i.e - * present in both the legs of the join) or else the resultant plan will be invalid. + * Check if we can safely push a join through a project, aggregate, or union by making sure that + * attributes referred in join condition do not contain the same attributes as the plan they are + * moved into. This can happen when both sides of join refers to the same source (self join). + * This function makes sure that the join condition refers to attributes that are not ambiguous + * (i.e present in both the legs of the join) or else the resultant plan will be invalid. */ private def canPushThroughCondition( plans: Seq[LogicalPlan], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index 9f77f448d233..04171a85eec6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType -class LeftSemiPushdownSuite extends PlanTest { +class LeftSemiAntiJoinPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest { val testRelation1 = LocalRelation($"d".int) val testRelation2 = LocalRelation($"e".int) - test("Project: LeftSemiAnti join pushdown") { + test("Project: LeftSemi join pushdown") { val originalQuery = testRelation .select(star()) .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) @@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") { + test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") { val originalQuery = testRelation .select(Rand(1), $"b", $"c") .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) @@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Project: LeftSemiAnti join non correlated scalar subq") { + test("Project: LeftSemi join pushdown - non-correlated scalar subq") { val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) val originalQuery = testRelation .select(subq.as("sum")) @@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") { + test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") { val testRelation2 = LocalRelation($"e".int, $"f".int) val subqPlan = testRelation2.groupBy($"e")(sum($"f").as("sum")).where($"e" === $"a") val subqExpr = ScalarSubquery(subqPlan) @@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Aggregate: LeftSemiAnti join pushdown") { + test("Aggregate: LeftSemi join pushdown") { val originalQuery = testRelation .groupBy($"b")($"b", sum($"c")) .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) @@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") { + test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") { val originalQuery = testRelation .groupBy($"b")($"b", Rand(10).as("c")) .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) @@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("LeftSemiAnti join over aggregate - no pushdown") { + test("Aggregate: LeftSemi join no pushdown") { val originalQuery = testRelation .groupBy($"b")($"b", sum($"c").as("sum")) .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d" && $"sum" === $"d")) @@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") { + test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") { val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) val originalQuery = testRelation .groupBy($"a") ($"a", subq.as("sum")) @@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("LeftSemiAnti join over Window") { + test("Window: LeftSemi join pushdown") { val winExpr = windowExpr(count($"b"), windowSpec($"a" :: Nil, $"b".asc :: Nil, UnspecifiedFrame)) @@ -185,7 +185,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Window: LeftSemi partial pushdown") { + test("Window: LeftSemi join partial pushdown") { // Attributes from join condition which does not refer to the window partition spec // are kept up in the plan as a Filter operator above Window. val winExpr = windowExpr(count($"b"), @@ -227,7 +227,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Union: LeftSemiAnti join pushdown") { + test("Union: LeftSemi join pushdown") { val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) val originalQuery = Union(Seq(testRelation, testRelation2)) @@ -243,7 +243,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Union: LeftSemiAnti join pushdown in self join scenario") { + test("Union: LeftSemi join pushdown in self join scenario") { val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) val attrX = testRelation2.output.head @@ -262,7 +262,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftSemiAnti join pushdown") { + test("Unary: LeftSemi join pushdown") { val originalQuery = testRelation .select(star()) .repartition(1) @@ -277,7 +277,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftSemiAnti join pushdown - empty join condition") { + test("Unary: LeftSemi join pushdown - empty join condition") { val originalQuery = testRelation .select(star()) .repartition(1) @@ -292,7 +292,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftSemi join pushdown - partial pushdown") { + test("Unary: LeftSemi join partial pushdown") { val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) val originalQuery = testRelationWithArrayType .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) @@ -309,7 +309,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftAnti join pushdown - no pushdown") { + test("Unary: LeftAnti join no pushdown") { val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) val originalQuery = testRelationWithArrayType .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) @@ -320,7 +320,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Unary: LeftSemiAnti join pushdown - no pushdown") { + test("Unary: LeftSemi join - no pushdown") { val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) val originalQuery = testRelationWithArrayType .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) @@ -331,7 +331,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Unary: LeftSemi join push down through Expand") { + test("Unary: LeftSemi join pushdown through Expand") { val expand = Expand(Seq(Seq($"a", $"b", "null"), Seq($"a", "null", $"c")), Seq($"a", $"b", $"c"), testRelation) val originalQuery = expand @@ -437,6 +437,25 @@ class LeftSemiPushdownSuite extends PlanTest { } } + Seq(LeftSemi, LeftAnti).foreach { case jt => + test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") { + val aggregation = testRelation + .select($"b".as("id"), $"c") + .groupBy($"id")($"id", sum($"c").as("sum")) + + // reference "b" exists in left leg, and the children of the right leg of the join + val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum") + .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select($"b".as("id"), $"c") + .groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum")) + .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) + .analyze + comparePlans(optimized, correctAnswer) + } + } + Seq(LeftSemi, LeftAnti).foreach { case outerJT => Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT => test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c841bffac8cd..e4f6b4cb40c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -344,6 +344,24 @@ class DataFrameJoinSuite extends QueryTest } } + Seq("left_semi", "left_anti").foreach { joinType => + test(s"SPARK-41162: $joinType self-joined aggregated dataframe") { + // aggregated dataframe + val ids = Seq(1, 2, 3).toDF("id").distinct() + + // self-joined via joinType + val result = ids.withColumn("id", $"id" + 1) + .join(ids, "id", joinType).collect() + + val expected = joinType match { + case "left_semi" => 2 + case "left_anti" => 1 + case _ => -1 // unsupported test type, test will always fail + } + assert(result.length == expected) + } + } + def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match { case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left) case Filter(_, child) => extractLeftDeepInnerJoins(child)