-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-36114][SQL] Support subqueries with correlated non-equality predicates #38135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest | |
| t.createOrReplaceTempView("t") | ||
| } | ||
|
|
||
| // Asserts that `plan` contains exactly `numJoins` Join nodes, as gathered by `plan.collect`. | ||
| // The clue reports the expected and actual counts together with the plan, so a failing | ||
| // test shows which plan shape was actually produced. | ||
| private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = { | ||
| val joinCount = plan.collect { case j: Join => j }.size | ||
| assert(joinCount == numJoins, | ||
| s"Expected $numJoins join(s) but found $joinCount in plan:\n$plan") | ||
| } | ||
|
|
||
| test("SPARK-18854 numberedTreeString for subquery") { | ||
| val df = sql("select * from range(10) where id not in " + | ||
| "(select id from range(2) union all select id from range(2))") | ||
|
|
@@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest | |
| } | ||
|
|
||
| test("non-equal correlated scalar subquery") { | ||
| val exception = intercept[AnalysisException] { | ||
| sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") | ||
| } | ||
| checkErrorMatchPVals( | ||
| exception, | ||
| errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + | ||
| "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", | ||
| parameters = Map("treeNode" -> "(?s).*"), | ||
| sqlState = None, | ||
| context = ExpectedContext( | ||
| fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51)) | ||
| checkAnswer( | ||
| sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"), | ||
| Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null), | ||
| Row(null, null), Row(6, 9))) | ||
| } | ||
|
|
||
| test("disjunctive correlated scalar subquery") { | ||
|
|
@@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest | |
| } | ||
| } | ||
|
|
||
| test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { | ||
| test("SPARK-36114: distinct aggregate in lateral subqueries") { | ||
| withTempView("t1", "t2") { | ||
| Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") | ||
| Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") | ||
| val exception = intercept[AnalysisException] { | ||
| sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") | ||
| } | ||
| checkErrorMatchPVals( | ||
| exception, | ||
| errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + | ||
| "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", | ||
| parameters = Map("treeNode" -> "(?s).*"), | ||
| sqlState = None, | ||
| context = ExpectedContext( | ||
| fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73)) | ||
| checkAnswer( | ||
| sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"), | ||
| Row(0, 1, 2) :: Nil) | ||
| } | ||
| } | ||
|
|
||
| test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { | ||
| test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") { | ||
| withTempView("t1", "t2") { | ||
| Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | ||
| Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") | ||
|
|
@@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest | |
| |FROM (SELECT CAST(c1 AS STRING) a FROM t1) | ||
| |""".stripMargin), | ||
| Row(5) :: Row(null) :: Nil) | ||
| val exception1 = intercept[AnalysisException] { | ||
| sql( | ||
| """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) | ||
| |FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin) | ||
| } | ||
| checkErrorMatchPVals( | ||
| exception1, | ||
| errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + | ||
| "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", | ||
| parameters = Map("treeNode" -> "(?s).*"), | ||
| sqlState = None, | ||
| context = ExpectedContext( | ||
| fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57)) | ||
| // SPARK-36114: we now allow non-safe cast expressions in correlated predicates. | ||
cloud-fan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| val df = sql( | ||
| """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) | ||
| |FROM (SELECT CAST(c1 AS SHORT) a FROM t1) | ||
| |""".stripMargin) | ||
| checkAnswer(df, Row(5) :: Row(null) :: Nil) | ||
| // The optimized plan should have one left outer join and one domain (inner) join. | ||
| checkNumJoins(df.queryExecution.optimizedPlan, 2) | ||
|
||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I have been missing some context:
Now that non-equality predicates are supported, what gaps remain? I am assuming all predicates are supported now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, you have an example below, which makes sense: