Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// 1 | 2 | 4
// and the plan after rewrite will give the original query incorrect results.
def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
if (predicates.nonEmpty) {
// Correlated non-equality predicates are only supported with the decorrelate
// inner query framework. Currently we only use this new framework for scalar
// and lateral subqueries.
val allowNonEqualityPredicates =
SQLConf.get.decorrelateInnerQueryEnabled && (isScalar || isLateral)
if (!allowNonEqualityPredicates && predicates.nonEmpty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I have been missing context:

Once non-equality predicates are supported, what gaps remain? I am assuming all predicates are supported now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh you have an example below which makes sense:

-- Correlated equality predicates that are not supported after SPARK-35080
SELECT c, (
    SELECT count(*)
    FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c)
    WHERE t1.c = substring(t2.c, 1, 1)
) FROM (VALUES ('a'), ('b')) t1(c);

// Report a non-supported case as an exception
p.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ class AnalysisErrorSuite extends AnalysisTest {
(And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d#x AS INT) = outer(c#x)"))
conditions.foreach { case (cond, msg) =>
val plan = Project(
ScalarSubquery(
Exists(
Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
Filter(cond, t1))
).as("sub") :: Nil,
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1);
-- lateral join with correlated non-equality predicates
SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c2 < t2.c2);

-- SPARK-36114: lateral join with aggregation and correlated non-equality predicates
SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2);

-- lateral join can reference preceding FROM clause items
SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2);
-- expect error: cannot resolve `t2.c1`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,48 @@ SELECT c1, (

-- Multi-value subquery error
SELECT (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b;

-- SPARK-36114: Support correlated non-equality predicates
CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2));
CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3));

-- Neumann example Q2
CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES
(0, 'A', 'CS', 2022),
(1, 'B', 'CS', 2022),
(2, 'C', 'Math', 2022));
CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES
(0, 'C1', 'CS', 4, 2020),
(0, 'C2', 'CS', 3, 2021),
(1, 'C1', 'CS', 2, 2020),
(1, 'C2', 'CS', 1, 2021));

SELECT students.name, exams.course
FROM students, exams
WHERE students.id = exams.sid
AND (students.major = 'CS' OR students.major = 'Games Eng')
AND exams.grade >= (
SELECT avg(exams.grade) + 1
FROM exams
WHERE students.id = exams.sid
OR (exams.curriculum = students.major AND students.year > exams.date));

-- Correlated non-equality predicates
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1;
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1;

-- Correlated non-equality predicates with the COUNT bug.
SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1;

-- Correlated equality predicates that are not supported after SPARK-35080
SELECT c, (
SELECT count(*)
FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c)
WHERE t1.c = substring(t2.c, 1, 1)
) FROM (VALUES ('a'), ('b')) t1(c);

SELECT c, (
SELECT count(*)
FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
WHERE a + b = c
) FROM (VALUES (6)) t2(c);
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,15 @@ struct<c1:int,c2:int,c2:int>
1 2 3


-- !query
SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2)
-- !query schema
struct<c1:int,c2:int,m:int>
-- !query output
0 1 3
1 2 3


-- !query
SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2)
-- !query schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,110 @@ org.apache.spark.SparkException
"fragment" : "(SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t)"
} ]
}


-- !query
CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES
(0, 'A', 'CS', 2022),
(1, 'B', 'CS', 2022),
(2, 'C', 'Math', 2022))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES
(0, 'C1', 'CS', 4, 2020),
(0, 'C2', 'CS', 3, 2021),
(1, 'C1', 'CS', 2, 2020),
(1, 'C2', 'CS', 1, 2021))
-- !query schema
struct<>
-- !query output



-- !query
SELECT students.name, exams.course
FROM students, exams
WHERE students.id = exams.sid
AND (students.major = 'CS' OR students.major = 'Games Eng')
AND exams.grade >= (
SELECT avg(exams.grade) + 1
FROM exams
WHERE students.id = exams.sid
OR (exams.curriculum = students.major AND students.year > exams.date))
-- !query schema
struct<name:string,course:string>
-- !query output
A C1


-- !query
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1
-- !query schema
struct<scalarsubquery(c1):int>
-- !query output
2
NULL


-- !query
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1
-- !query schema
struct<scalarsubquery(c1, c2):int>
-- !query output
2
3


-- !query
SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1
-- !query schema
struct<scalarsubquery(c1):bigint>
-- !query output
0
2


-- !query
SELECT c, (
SELECT count(*)
FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c)
WHERE t1.c = substring(t2.c, 1, 1)
) FROM (VALUES ('a'), ('b')) t1(c)
-- !query schema
struct<c:string,scalarsubquery(c):bigint>
-- !query output
a 2
b 1


-- !query
SELECT c, (
SELECT count(*)
FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
WHERE a + b = c
) FROM (VALUES (6)) t2(c)
-- !query schema
struct<c:int,scalarsubquery(c):bigint>
-- !query output
6 4
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,6 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v))
FROM t2
WHERE t2.k = t1.k)
-- !query schema
struct<>
struct<k:string>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
"messageParameters" : {
"treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 39,
"stopIndex" : 141,
"fragment" : "SELECT udf(max(udf(t2.v)))\n FROM t2\n WHERE udf(t2.k) = udf(t1.k)"
} ]
}
two
59 changes: 22 additions & 37 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest
t.createOrReplaceTempView("t")
}

/**
 * Asserts that the number of `Join` nodes gathered from `plan` via `collect`
 * equals `numJoins`.
 *
 * @param plan     the logical plan to inspect
 * @param numJoins the expected number of `Join` nodes
 */
private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = {
  // Gather every node matching `Join` and compare the count to the expectation.
  val actualJoinCount = plan.collect { case join: Join => join }.length
  assert(actualJoinCount == numJoins)
}

test("SPARK-18854 numberedTreeString for subquery") {
val df = sql("select * from range(10) where id not in " +
"(select id from range(2) union all select id from range(2))")
Expand Down Expand Up @@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest
}

test("non-equal correlated scalar subquery") {
val exception = intercept[AnalysisException] {
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
checkErrorMatchPVals(
exception,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51))
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"),
Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null),
Row(null, null), Row(6, 9)))
}

test("disjunctive correlated scalar subquery") {
Expand Down Expand Up @@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest
}
}

test("SPARK-38155: disallow distinct aggregate in lateral subqueries") {
test("SPARK-36114: distinct aggregate in lateral subqueries") {
withTempView("t1", "t2") {
Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
val exception = intercept[AnalysisException] {
sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)")
}
checkErrorMatchPVals(
exception,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73))
checkAnswer(
sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"),
Row(0, 1, 2) :: Nil)
}
}

test("SPARK-38180: allow safe cast expressions in correlated equality conditions") {
test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") {
withTempView("t1", "t2") {
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2")
Expand All @@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest
|FROM (SELECT CAST(c1 AS STRING) a FROM t1)
|""".stripMargin),
Row(5) :: Row(null) :: Nil)
val exception1 = intercept[AnalysisException] {
sql(
"""SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a)
|FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin)
}
checkErrorMatchPVals(
exception1,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57))
// SPARK-36114: we now allow non-safe cast expressions in correlated predicates.
val df = sql(
"""SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a)
|FROM (SELECT CAST(c1 AS SHORT) a FROM t1)
|""".stripMargin)
checkAnswer(df, Row(5) :: Row(null) :: Nil)
// The optimized plan should have one left outer join and one domain (inner) join.
checkNumJoins(df.queryExecution.optimizedPlan, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am missing context here: why do we need to check the number of joins only for this case? I did a code search, and it seems that the other test cases in this suite do not check the number of joins.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am verifying that the optimized plan has one left outer join and one domain (inner) join.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check the number of joins for the safe cast case as well?

}
}

Expand Down