Skip to content

Commit 0acfc57

Browse files
committed
Add new logic node AggregateWithHaving
1 parent cf60384 commit 0acfc57

File tree

5 files changed

+38
-8
lines changed

5 files changed

+38
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,13 @@ class Analyzer(
238238
ResolveNaturalAndUsingJoin ::
239239
ResolveOutputRelation ::
240240
ExtractWindowExpressions ::
241+
ResolveTimeZone(conf) ::
241242
GlobalAggregates ::
242243
ResolveAggregateFunctions ::
243244
TimeWindowing ::
244245
ResolveInlineTables(conf) ::
245246
ResolveHigherOrderFunctions(v1SessionCatalog) ::
246247
ResolveLambdaVariables(conf) ::
247-
ResolveTimeZone(conf) ::
248248
ResolveRandomSeed ::
249249
ResolveBinaryArithmetic(conf) ::
250250
TypeCoercion.typeCoercionRules(conf) ++
@@ -1393,6 +1393,9 @@ class Analyzer(
13931393
notMatchedActions = newNotMatchedActions)
13941394
}
13951395

1396+
// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
1397+
case h: AggregateWithHaving => h
1398+
13961399
case q: LogicalPlan =>
13971400
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
13981401
q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2033,8 +2036,9 @@ class Analyzer(
20332036
*/
20342037
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
20352038
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
2036-
case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
2037-
2039+
case AggregateWithHaving(
2040+
cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
2041+
val having = Filter(cond, agg)
20382042
// Try resolving the condition of the filter as though it is in the aggregate clause
20392043
try {
20402044
val aggregatedCondition =
@@ -2079,15 +2083,15 @@ class Analyzer(
20792083
Filter(transformedAggregateFilter,
20802084
agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
20812085
} else {
2082-
f
2086+
having
20832087
}
20842088
} else {
2085-
f
2089+
having
20862090
}
20872091
} catch {
20882092
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
20892093
// just return the original plan.
2090-
case ae: AnalysisException => f
2094+
case ae: AnalysisException => having
20912095
}
20922096

20932097
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
629629
case p: Predicate => p
630630
case e => Cast(e, BooleanType)
631631
}
632-
Filter(predicate, plan)
632+
plan match {
633+
case aggregate: Aggregate =>
634+
AggregateWithHaving(predicate, aggregate)
635+
case _ =>
636+
Filter(predicate, plan)
637+
}
633638
}
634639

635640
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,16 @@ case class Aggregate(
583583
}
584584
}
585585

586+
case class AggregateWithHaving(
587+
havingCondition: Expression,
588+
child: LogicalPlan)
589+
extends UnaryNode {
590+
591+
override lazy val resolved: Boolean = false
592+
593+
override def output: Seq[Attribute] = child.output
594+
}
595+
586596
case class Window(
587597
windowExpressions: Seq[NamedExpression],
588598
partitionSpec: Seq[Expression],

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
1616

1717
-- SPARK-20329: make sure we handle timezones correctly
1818
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
19+
20+
-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
21+
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10

sql/core/src/test/resources/sql-tests/results/having.sql.out

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 5
2+
-- Number of queries: 6
33

44

55
-- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
4747
-- !query output
4848
3
4949
7
50+
51+
52+
-- !query
53+
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
54+
-- !query schema
55+
struct<b:bigint,fake:date>
56+
-- !query output
57+
2 2020-01-01

0 commit comments

Comments
 (0)