Commit 6ed2dfb

xuanyuanking authored and cloud-fan committed
[SPARK-31519][SQL] Cast in having aggregate expressions returns the wrong result
### What changes were proposed in this pull request?

Add a new logical node AggregateWithHaving, and have the parser create this plan for HAVING. The analyzer resolves it to Filter(..., Aggregate(...)).

### Why are the changes needed?

The SQL parser in Spark creates Filter(..., Aggregate(...)) for a HAVING query, and Spark has a special analyzer rule, ResolveAggregateFunctions, to resolve the aggregate functions and grouping columns in the Filter operator.

This works for simple cases, but in a very tricky way, as it relies on rule execution order:

1. Rule ResolveReferences hits the Aggregate operator and resolves attributes inside aggregate functions, but the function itself is still unresolved as it's an UnresolvedFunction. This stops resolving the Filter operator as the child Aggregate operator is still unresolved.
2. Rule ResolveFunctions resolves UnresolvedFunction. This makes the Aggregate operator resolved.
3. Rule ResolveAggregateFunctions resolves the Filter operator if its child is a resolved Aggregate. This rule can correctly resolve the grouping columns.

In the example query, I put a CAST, which needs to be resolved by rule ResolveTimeZone, which runs after ResolveAggregateFunctions. This breaks step 3, as the Aggregate operator is still unresolved at that time. Then the analyzer starts the next round and the Filter operator is resolved by ResolveReferences, which wrongly resolves the grouping columns.

See the demo below:

```
SELECT SUM(a) AS b, '2020-01-01' AS fake
FROM VALUES (1, 10), (2, 20) AS T(a, b)
GROUP BY b
HAVING b > 10
```

The query's result is

```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```

But if we add CAST, it will return an empty result.

```
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake
FROM VALUES (1, 10), (2, 20) AS T(a, b)
GROUP BY b
HAVING b > 10
```

### Does this PR introduce any user-facing change?

Yes, bug fix for cast in having aggregate expressions.

### How was this patch tested?

New UT added.

Closes #28294 from xuanyuanking/SPARK-31519.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
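To see why the two possible bindings of `b` in `HAVING b > 10` give different answers, the example query can be modeled with plain Scala collections. This is only an illustrative sketch with no Spark dependency: the names `HavingSemantics`, `byGroupingColumn`, and `byAlias` are made up here, and reading the wrong resolution as "b bound to the alias SUM(a) AS b" is an assumption, chosen because it reproduces the empty result described in the commit message.

```scala
// Plain-Scala model of the example query; nothing here is Spark API.
object HavingSemantics {
  // Rows of VALUES (1, 10), (2, 20) AS T(a, b).
  val rows: Seq[(Int, Int)] = Seq((1, 10), (2, 20))

  // GROUP BY b with SUM(a): one (groupingB, sumA) pair per group.
  val groups: Seq[(Int, Int)] =
    rows.groupBy(_._2).toSeq.sortBy(_._1).map { case (b, rs) => (b, rs.map(_._1).sum) }

  // HAVING b > 10 with b bound to the grouping column T.b (the intended reading).
  val byGroupingColumn: Seq[Int] = groups.collect { case (b, sumA) if b > 10 => sumA }

  // HAVING b > 10 with b bound to the alias SUM(a) AS b (assumed model of the wrong reading).
  val byAlias: Seq[Int] = groups.collect { case (_, sumA) if sumA > 10 => sumA }

  def main(args: Array[String]): Unit = {
    assert(byGroupingColumn == Seq(2)) // matches the one-row result shown for the query
    assert(byAlias.isEmpty)            // matches the empty result seen with CAST
  }
}
```

The group with `T.b = 20` survives under the first binding and yields `SUM(a) = 2`, while under the second binding neither sum (1 or 2) exceeds 10.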
1 parent 079b362 commit 6ed2dfb

9 files changed, +135 -80 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 74 additions & 59 deletions
@@ -1239,13 +1239,13 @@ class Analyzer(
     /**
      * Resolves the attribute and extract value expressions(s) by traversing the
      * input expression in top down manner. The traversal is done in top-down manner as
-     * we need to skip over unbound lamda function expression. The lamda expressions are
+     * we need to skip over unbound lambda function expression. The lambda expressions are
      * resolved in a different rule [[ResolveLambdaVariables]]
      *
      * Example :
      * SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
      *
-     * In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
+     * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
      *
      * Note : In this routine, the unresolved attributes are resolved from the input plan's
      * children attributes.
@@ -1400,6 +1400,9 @@ class Analyzer(
             notMatchedActions = newNotMatchedActions)
         }

+      // Skip the having clause here, this will be handled in ResolveAggregateFunctions.
+      case h: AggregateWithHaving => h
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
         q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2040,62 +2043,14 @@ class Analyzer(
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
+      // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
+      // resolve the having condition expression, here we skip resolving it in ResolveReferences
+      // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
+      case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+        resolveHaving(Filter(cond, agg), agg)

-        // Try resolving the condition of the filter as though it is in the aggregate clause
-        try {
-          val aggregatedCondition =
-            Aggregate(
-              grouping,
-              Alias(cond, "havingCondition")() :: Nil,
-              child)
-          val resolvedOperator = executeSameContext(aggregatedCondition)
-          def resolvedAggregateFilter =
-            resolvedOperator
-              .asInstanceOf[Aggregate]
-              .aggregateExpressions.head
-
-          // If resolution was successful and we see the filter has an aggregate in it, add it to
-          // the original aggregate operator.
-          if (resolvedOperator.resolved) {
-            // Try to replace all aggregate expressions in the filter by an alias.
-            val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
-            val transformedAggregateFilter = resolvedAggregateFilter.transform {
-              case ae: AggregateExpression =>
-                val alias = Alias(ae, ae.toString)()
-                aggregateExpressions += alias
-                alias.toAttribute
-              // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
-              case e: Expression if grouping.exists(_.semanticEquals(e)) &&
-                  !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
-                  !agg.output.exists(_.semanticEquals(e)) =>
-                e match {
-                  case ne: NamedExpression =>
-                    aggregateExpressions += ne
-                    ne.toAttribute
-                  case _ =>
-                    val alias = Alias(e, e.toString)()
-                    aggregateExpressions += alias
-                    alias.toAttribute
-                }
-            }
-
-            // Push the aggregate expressions into the aggregate (if any).
-            if (aggregateExpressions.nonEmpty) {
-              Project(agg.output,
-                Filter(transformedAggregateFilter,
-                  agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
-            } else {
-              f
-            }
-          } else {
-            f
-          }
-        } catch {
-          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
-          // just return the original plan.
-          case ae: AnalysisException => f
-        }
+      case f @ Filter(_, agg: Aggregate) if agg.resolved =>
+        resolveHaving(f, agg)

       case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

@@ -2166,6 +2121,63 @@ class Analyzer(
     def containsAggregate(condition: Expression): Boolean = {
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
+
+    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      try {
+        val aggregatedCondition =
+          Aggregate(
+            agg.groupingExpressions,
+            Alias(filter.condition, "havingCondition")() :: Nil,
+            agg.child)
+        val resolvedOperator = executeSameContext(aggregatedCondition)
+        def resolvedAggregateFilter =
+          resolvedOperator
+            .asInstanceOf[Aggregate]
+            .aggregateExpressions.head
+
+        // If resolution was successful and we see the filter has an aggregate in it, add it to
+        // the original aggregate operator.
+        if (resolvedOperator.resolved) {
+          // Try to replace all aggregate expressions in the filter by an alias.
+          val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+          val transformedAggregateFilter = resolvedAggregateFilter.transform {
+            case ae: AggregateExpression =>
+              val alias = Alias(ae, ae.toString)()
+              aggregateExpressions += alias
+              alias.toAttribute
+            // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
+            case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+                !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
+                !agg.output.exists(_.semanticEquals(e)) =>
+              e match {
+                case ne: NamedExpression =>
+                  aggregateExpressions += ne
+                  ne.toAttribute
+                case _ =>
+                  val alias = Alias(e, e.toString)()
+                  aggregateExpressions += alias
+                  alias.toAttribute
+              }
+          }
+
+          // Push the aggregate expressions into the aggregate (if any).
+          if (aggregateExpressions.nonEmpty) {
+            Project(agg.output,
+              Filter(transformedAggregateFilter,
+                agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+          } else {
+            filter
+          }
+        } else {
+          filter
+        }
+      } catch {
+        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+        // just return the original plan.
+        case ae: AnalysisException => filter
+      }
+    }
   }

   /**
@@ -2590,11 +2602,14 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {

       case Filter(condition, _) if hasWindowFunction(condition) =>
-        failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
+        failAnalysis("It is not allowed to use window functions inside WHERE clause")
+
+      case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+        failAnalysis("It is not allowed to use window functions inside HAVING clause")

       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
            hasWindowFunction(aggregateExprs) &&
            a.expressions.forall(_.resolved) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 12 additions & 1 deletion
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.parser.ParserUtils
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
@@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+/**
+ * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
+ */
+case class AggregateWithHaving(
+    havingCondition: Expression,
+    child: Aggregate)
+  extends UnaryNode {
+  override lazy val resolved: Boolean = false
+  override def output: Seq[Attribute] = child.output
+}
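
The idea behind this placeholder node can be sketched with a toy model: the node never reports itself as resolved, the generic resolution step passes it through untouched, and it is rewritten into a Filter only once its Aggregate child is resolved. Everything below is a simplified stand-in written for illustration, not Spark's API; the types only borrow the names from the diff, and `ToyAnalyzer`, `finishAggregate`, and the string-typed condition are hypothetical.

```scala
// Toy model only: these types mimic the shape of the change, not Spark's classes.
sealed trait Plan { def resolved: Boolean }
final case class Relation(name: String) extends Plan { val resolved: Boolean = true }
final case class Aggregate(resolved: Boolean, child: Plan) extends Plan
final case class Filter(condition: String, child: Plan) extends Plan {
  def resolved: Boolean = child.resolved
}
// Placeholder produced for HAVING; it never reports itself as resolved.
final case class AggregateWithHaving(condition: String, child: Aggregate) extends Plan {
  val resolved: Boolean = false
}

object ToyAnalyzer {
  // Stand-in for a generic reference-resolution rule: the placeholder is skipped.
  def resolveReferences(plan: Plan): Plan = plan match {
    case h: AggregateWithHaving => h // left for resolveHaving
    case other => other              // a real rule would resolve attributes here
  }

  // Stand-in for whichever later rules finish resolving the Aggregate child.
  def finishAggregate(plan: Plan): Plan = plan match {
    case AggregateWithHaving(cond, agg) => AggregateWithHaving(cond, agg.copy(resolved = true))
    case other => other
  }

  // Stand-in for the HAVING rule: rewrite only when the child is resolved.
  def resolveHaving(plan: Plan): Plan = plan match {
    case AggregateWithHaving(cond, agg) if agg.resolved => Filter(cond, agg)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val parsed: Plan =
      AggregateWithHaving("b > 10", Aggregate(resolved = false, child = Relation("T")))

    // Round 1: the Aggregate is not resolved yet, so the placeholder survives as-is.
    val round1 = resolveHaving(resolveReferences(parsed))
    assert(round1 == parsed)

    // Round 2: once the Aggregate is resolved, the placeholder becomes a Filter.
    val round2 = resolveHaving(finishAggregate(round1))
    assert(round2 == Filter("b > 10", Aggregate(resolved = true, child = Relation("T"))))
  }
}
```

Because the placeholder is a distinct node type, no rule that matches on Filter can touch the HAVING condition early, which is what removes the dependence on rule ordering in this sketch.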

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 8 additions & 0 deletions
@@ -364,6 +364,14 @@ package object dsl {
         Aggregate(groupingExprs, aliasedExprs, logicalPlan)
       }

+      def having(
+          groupingExprs: Expression*)(
+          aggregateExprs: Expression*)(
+          havingCondition: Expression): LogicalPlan = {
+        AggregateWithHaving(havingCondition,
+          groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
+      }
+
       def window(
           windowExpressions: Seq[NamedExpression],
           partitionSpec: Seq[Expression],

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 6 additions & 1 deletion
@@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case p: Predicate => p
       case e => Cast(e, BooleanType)
     }
-    Filter(predicate, plan)
+    plan match {
+      case aggregate: Aggregate =>
+        AggregateWithHaving(predicate, aggregate)
+      case _ =>
+        Filter(predicate, plan)
+    }
   }

   /**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 2 additions & 3 deletions
@@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
     assertEqual(
       "select a, b from db.c having x < 1",
-      table("db", "c").groupBy()('a, 'b).where('x < 1))
+      table("db", "c").having()('a, 'b)('x < 1))
     assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
     assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
     assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual(
       "select g from t group by g having a > (select b from s)",
       table("t")
-        .groupBy('g)('g)
-        .where('a > ScalarSubquery(table("s").select('b))))
+        .having('g)('g)('a > ScalarSubquery(table("s").select('b))))
   }

   test("table reference") {

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 3 additions & 0 deletions
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);

 -- SPARK-20329: make sure we handle timezones correctly
 SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
+
+-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10

sql/core/src/test/resources/sql-tests/results/having.sql.out

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 5
+-- Number of queries: 6


 -- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
 -- !query output
 3
 7
+
+
+-- !query
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+-- !query schema
+struct<b:bigint,fake:date>
+-- !query output
+2 2020-01-01

sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out

Lines changed: 3 additions & 3 deletions
@@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query
@@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query
@@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-It is not allowed to use window functions inside WHERE and HAVING clauses;
+It is not allowed to use window functions inside WHERE clause;


 -- !query

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 18 additions & 12 deletions
@@ -665,40 +665,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest
   }

   test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
-    def checkAnalysisError(df: => DataFrame): Unit = {
+    def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
       val thrownException = the[AnalysisException] thrownBy {
         df.queryExecution.analyzed
       }
-      assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
+      assert(thrownException.message.contains(s"window functions inside $clause clause"))
     }

-    checkAnalysisError(testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1))
-    checkAnalysisError(testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1))
+    checkAnalysisError(
+      testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE")
+    checkAnalysisError(
+      testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(avg($"b").as("avgb"))
-        .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1))
+        .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
-        .where(rank().over(Window.orderBy($"a")) === 1))
+        .where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
     checkAnalysisError(
       testData2.groupBy($"a")
         .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
-        .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1))
+        .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE")

-    checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
-    checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
+    checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE")
+    checkAnalysisError(
+      sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE")
     checkAnalysisError(
-      sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
+      sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"),
+      "HAVING")
     checkAnalysisError(
-      sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
+      sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"),
+      "HAVING")
     checkAnalysisError(
       sql(
         s"""SELECT a, MAX(b)
           |FROM testData2
           |GROUP BY a
-          |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
+          |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
+      "HAVING")
   }

   test("window functions in multiple selects") {
