-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-41162][SQL] Fix anti- and semi-join for self-join with aggregations #39131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eebeb34
76dcd1e
ade61b8
cc28b21
2ae3a41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._ | |
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.types.IntegerType | ||
|
|
||
| class LeftSemiPushdownSuite extends PlanTest { | ||
| class LeftSemiAntiJoinPushDownSuite extends PlanTest { | ||
|
|
||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = | ||
|
|
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| val testRelation1 = LocalRelation($"d".int) | ||
| val testRelation2 = LocalRelation($"e".int) | ||
|
|
||
| test("Project: LeftSemiAnti join pushdown") { | ||
| test("Project: LeftSemi join pushdown") { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These change to test names are necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The term |
||
| val originalQuery = testRelation | ||
| .select(star()) | ||
| .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
|
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") { | ||
| test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") { | ||
| val originalQuery = testRelation | ||
| .select(Rand(1), $"b", $"c") | ||
| .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
|
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("Project: LeftSemiAnti join non correlated scalar subq") { | ||
| test("Project: LeftSemi join pushdown - non-correlated scalar subq") { | ||
| val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) | ||
| val originalQuery = testRelation | ||
| .select(subq.as("sum")) | ||
|
|
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") { | ||
| test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") { | ||
| val testRelation2 = LocalRelation($"e".int, $"f".int) | ||
| val subqPlan = testRelation2.groupBy($"e")(sum($"f").as("sum")).where($"e" === $"a") | ||
| val subqExpr = ScalarSubquery(subqPlan) | ||
|
|
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("Aggregate: LeftSemiAnti join pushdown") { | ||
| test("Aggregate: LeftSemi join pushdown") { | ||
| val originalQuery = testRelation | ||
| .groupBy($"b")($"b", sum($"c")) | ||
| .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
|
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") { | ||
| test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") { | ||
| val originalQuery = testRelation | ||
| .groupBy($"b")($"b", Rand(10).as("c")) | ||
| .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
|
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("LeftSemiAnti join over aggregate - no pushdown") { | ||
| test("Aggregate: LeftSemi join no pushdown") { | ||
| val originalQuery = testRelation | ||
| .groupBy($"b")($"b", sum($"c").as("sum")) | ||
| .join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d" && $"sum" === $"d")) | ||
|
|
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") { | ||
| test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") { | ||
| val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) | ||
| val originalQuery = testRelation | ||
| .groupBy($"a") ($"a", subq.as("sum")) | ||
|
|
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("LeftSemiAnti join over Window") { | ||
| test("Window: LeftSemi join pushdown") { | ||
| val winExpr = windowExpr(count($"b"), | ||
| windowSpec($"a" :: Nil, $"b".asc :: Nil, UnspecifiedFrame)) | ||
|
|
||
|
|
@@ -185,7 +185,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Window: LeftSemi partial pushdown") { | ||
| test("Window: LeftSemi join partial pushdown") { | ||
| // Attributes from join condition which does not refer to the window partition spec | ||
| // are kept up in the plan as a Filter operator above Window. | ||
| val winExpr = windowExpr(count($"b"), | ||
|
|
@@ -227,7 +227,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Union: LeftSemiAnti join pushdown") { | ||
| test("Union: LeftSemi join pushdown") { | ||
| val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) | ||
|
|
||
| val originalQuery = Union(Seq(testRelation, testRelation2)) | ||
|
|
@@ -243,7 +243,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Union: LeftSemiAnti join pushdown in self join scenario") { | ||
| test("Union: LeftSemi join pushdown in self join scenario") { | ||
| val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) | ||
| val attrX = testRelation2.output.head | ||
|
|
||
|
|
@@ -262,7 +262,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Unary: LeftSemiAnti join pushdown") { | ||
| test("Unary: LeftSemi join pushdown") { | ||
| val originalQuery = testRelation | ||
| .select(star()) | ||
| .repartition(1) | ||
|
|
@@ -277,7 +277,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Unary: LeftSemiAnti join pushdown - empty join condition") { | ||
| test("Unary: LeftSemi join pushdown - empty join condition") { | ||
| val originalQuery = testRelation | ||
| .select(star()) | ||
| .repartition(1) | ||
|
|
@@ -292,7 +292,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Unary: LeftSemi join pushdown - partial pushdown") { | ||
| test("Unary: LeftSemi join partial pushdown") { | ||
| val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
| val originalQuery = testRelationWithArrayType | ||
| .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
|
@@ -309,7 +309,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("Unary: LeftAnti join pushdown - no pushdown") { | ||
| test("Unary: LeftAnti join no pushdown") { | ||
| val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
| val originalQuery = testRelationWithArrayType | ||
| .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
|
@@ -320,7 +320,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("Unary: LeftSemiAnti join pushdown - no pushdown") { | ||
| test("Unary: LeftSemi join - no pushdown") { | ||
| val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
| val originalQuery = testRelationWithArrayType | ||
| .generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
|
@@ -331,7 +331,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| comparePlans(optimized, originalQuery.analyze) | ||
| } | ||
|
|
||
| test("Unary: LeftSemi join push down through Expand") { | ||
| test("Unary: LeftSemi join pushdown through Expand") { | ||
| val expand = Expand(Seq(Seq($"a", $"b", "null"), Seq($"a", "null", $"c")), | ||
| Seq($"a", $"b", $"c"), testRelation) | ||
| val originalQuery = expand | ||
|
|
@@ -437,6 +437,25 @@ class LeftSemiPushdownSuite extends PlanTest { | |
| } | ||
| } | ||
|
|
||
| Seq(LeftSemi, LeftAnti).foreach { case jt => | ||
| test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") { | ||
| val aggregation = testRelation | ||
| .select($"b".as("id"), $"c") | ||
| .groupBy($"id")($"id", sum($"c").as("sum")) | ||
|
|
||
| // reference "b" exists in left leg, and the children of the right leg of the join | ||
| val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum") | ||
| .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) | ||
| val optimized = Optimize.execute(originalQuery.analyze) | ||
| val correctAnswer = testRelation | ||
| .select($"b".as("id"), $"c") | ||
| .groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum")) | ||
| .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) | ||
| .analyze | ||
| comparePlans(optimized, correctAnswer) | ||
| } | ||
| } | ||
|
|
||
| Seq(LeftSemi, LeftAnti).foreach { case outerJT => | ||
| Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT => | ||
| test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should rewrite the
joinConditionassuming the join has already been pushed through Aggregate. That said, we need to do alias replacement forjoinConditionfirst. cc @EnricoMiThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand. The
canPushThroughConditionis called before theJoinis being pushed through theAggregate, it has been added to prevent this from happening in this situation. The other cases (e.g.Union) are calling intocanPushThroughConditionequivalently.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nvm,
canPushThroughConditionchecks the right side references of the join condition, and check if the right side references have conflict expr ID with left side plan (below Project) output. It doesn't care about the left side references of the join condition.