diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 9f914865b3a2..75ca4930cf8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -257,12 +257,16 @@ object SubExprUtils extends PredicateHelper { * We can derive these from correlated equality predicates, though we need to take care about * propagating this through operators like OUTER JOIN or UNION. * - * Positive examples: x = outer(a) AND y = outer(b) + * Positive examples: + * - x = outer(a) AND y = outer(b) + * - x = 1 + * - x = outer(a) + 1 + * * Negative examples: * - x <= outer(a) * - x + y = outer(a) * - x = outer(a) OR y = outer(b) - * - y = outer(b) + 1 (this and similar expressions could be supported, but very carefully) + * - y + outer(b) = 1 (this and similar expressions could be supported, but very carefully) * - An equality under the right side of a LEFT OUTER JOIN, e.g. * select *, (select count(*) from y left join * (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; @@ -274,7 +278,9 @@ object SubExprUtils extends PredicateHelper { plan match { case Filter(cond, child) => val correlated = AttributeSet(splitConjunctivePredicates(cond) - .filter(containsOuter) // TODO: can remove this line to allow e.g. where x = 1 group by x + .filter( + SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT) + || containsOuter(_)) .filter(DecorrelateInnerQuery.canPullUpOverAgg) .flatMap(_.references)) correlated ++ getCorrelatedEquivalentInnerColumns(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f4751f202789..9b0f81c12412 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4940,6 +4940,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT = + buildConf("spark.sql.analyzer.scalarSubqueryAllowGroupByColumnEqualToConstant") + .internal() + .doc("When set to true, allow scalar subqueries with group-by on a column that also " + + " has an equality filter with a constant (SPARK-48557).") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS = buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions") .internal() diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index d9eff3459235..671557aa3956 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -77,6 +77,38 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y2#x], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND (y2#x = 1)) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL] +: +- Aggregate [y2#x], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND (y2#x = (outer(x1#x) + 1))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis @@ -117,26 +149,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } --- !query -select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x --- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y2" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 71, - "fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by y2)" - } ] -} - - -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql index 627b27ad285b..6787fac75b39 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql @@ -11,13 +11,15 @@ select * from x where (select count(*) from y where y1 = x1 group by y1) = 1; select * from x where (select count(*) from y where y1 = x1 group by x1) = 1; select * from x where (select count(*) from y where y1 > x1 group by x1) = 1; +-- Group-by column equal to constant - legal +select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x; +-- Group-by column equal to expression with constants and outer refs - legal +select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) from x; + -- Illegal queries select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; --- Equality with literal - disallowed currently but can actually be allowed -select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x; - -- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed. diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index c044e59a26fd..85ebd91c28c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -75,6 +75,24 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema @@ -119,28 +137,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } --- !query -select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y2" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 71, - "fragment" : "(select count(*) from y where x1 = y1 and y2 = 1 group by y2)" - } ] -} - - -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query schema