diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md
index 4f50ba6d440c..34e8cfb02c9f 100644
--- a/docs/sql-keywords.md
+++ b/docs/sql-keywords.md
@@ -117,6 +117,7 @@ Below is a list of all the keywords in Spark SQL.
| FALSE | reserved | non-reserved | reserved |
| FETCH | reserved | non-reserved | reserved |
| FIELDS | non-reserved | non-reserved | non-reserved |
| FILEFORMAT | non-reserved | non-reserved | non-reserved |
+ | FILTER | reserved | non-reserved | reserved |
| FIRST | non-reserved | non-reserved | non-reserved |
| FIRST_VALUE | reserved | non-reserved | reserved |
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index cd9748eaa6f2..8ad82b3ed99a 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -766,7 +766,7 @@ primaryExpression
| '(' namedExpression (',' namedExpression)+ ')' #rowConstructor
| '(' query ')' #subqueryExpression
| functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
- (OVER windowSpec)? #functionCall
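+ // Optional ANSI SQL filter clause: func(args) FILTER (WHERE predicate)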
+ (FILTER '(' WHERE where=booleanExpression ')')? (OVER windowSpec)? #functionCall
| identifier '->' expression #lambda
| '(' identifier (',' identifier)+ ')' '->' expression #lambda
| value=primaryExpression '[' index=valueExpression ']' #subscript
@@ -931,6 +931,7 @@ qualifiedNameList
functionName
: qualifiedName
+ | FILTER
| LEFT
| RIGHT
;
@@ -1286,6 +1287,7 @@ nonReserved
| EXTRACT
| FALSE
| FETCH
| FIELDS
| FILEFORMAT
+ | FILTER
| FIRST
@@ -1549,6 +1551,7 @@ EXTRACT: 'EXTRACT';
FALSE: 'FALSE';
FETCH: 'FETCH';
FIELDS: 'FIELDS';
FILEFORMAT: 'FILEFORMAT';
+FILTER: 'FILTER';
FIRST: 'FIRST';
FIRST_VALUE: 'FIRST_VALUE';
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2438ef921822..bfbe3b4a9c94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1381,8 +1381,8 @@ class Analyzer(
*/
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
- case f1: UnresolvedFunction if containsStar(f1.children) =>
- f1.copy(children = f1.children.flatMap {
+ case f1: UnresolvedFunction if containsStar(f1.arguments) =>
+ f1.copy(arguments = f1.arguments.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
@@ -1734,26 +1734,37 @@ class Analyzer(
s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
}
}
- case u @ UnresolvedFunction(funcId, children, isDistinct) =>
+ case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter) =>
withPosition(u) {
- v1SessionCatalog.lookupFunction(funcId, children) match {
+ v1SessionCatalog.lookupFunction(funcId, arguments) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
case wf: AggregateWindowFunction =>
- if (isDistinct) {
- failAnalysis(
- s"DISTINCT specified, but ${wf.prettyName} is not an aggregate function")
+ if (isDistinct || filter.isDefined) {
+ failAnalysis("DISTINCT or FILTER specified, " +
+ s"but ${wf.prettyName} is not an aggregate function")
} else {
wf
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
- case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
+ case agg: AggregateFunction =>
+ // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
+ if (filter.isDefined) {
+ if (isDistinct) {
+ failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
+ "at the same time")
+ } else if (!filter.get.deterministic) {
+ failAnalysis("FILTER expression is non-deterministic, " +
+ "it cannot be used in aggregate functions")
+ }
+ }
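+ // Attach the filter to the AggregateExpression; it is evaluated against each
+ // input row before the row updates the aggregation buffer (see AggregationIterator).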
+ AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
case other =>
- if (isDistinct) {
- failAnalysis(
- s"DISTINCT specified, but ${other.prettyName} is not an aggregate function")
+ if (isDistinct || filter.isDefined) {
+ failAnalysis("DISTINCT or FILTER specified, " +
+ s"but ${other.prettyName} is not an aggregate function")
} else {
other
}
@@ -2351,7 +2362,7 @@ class Analyzer(
// Extract Windowed AggregateExpression
case we @ WindowExpression(
- ae @ AggregateExpression(function, _, _, _),
+ ae @ AggregateExpression(function, _, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
@@ -2359,7 +2370,7 @@ class Analyzer(
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
- case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
+ case AggregateExpression(aggFunc, _, _, _, _) if hasWindowFunction(aggFunc.children) =>
failAnalysis("It is not allowed to use a window function inside an aggregate " +
"function. Please use the inner window function in a sub-query.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index cfb16233b394..9268c7c3c944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -155,7 +155,7 @@ trait CheckAnalysis extends PredicateHelper {
case g: GroupingID =>
failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup")
- case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
+ case w @ WindowExpression(AggregateExpression(_, _, true, _, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
case w @ WindowExpression(_: OffsetWindowFunction,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index 1cd7f412bb67..11f94762d43e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -33,11 +33,14 @@ import org.apache.spark.sql.types.DataType
case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
- case u @ UnresolvedFunction(fn, children, false)
+ case u @ UnresolvedFunction(fn, children, false, filter)
if hasLambdaAndResolvedArguments(children) =>
withPosition(u) {
catalog.lookupFunction(fn, children) match {
- case func: HigherOrderFunction => func
+ case func: HigherOrderFunction =>
+ filter.foreach(_.failAnalysis("FILTER predicate specified, " +
+ s"but ${func.prettyName} is not an aggregate function"))
+ func
case other => other.failAnalysis(
"A lambda function should only be used in a higher order function. However, " +
s"its class is ${other.getClass.getCanonicalName}, which is not a " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e5a6f30c330e..608f39c2d86f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -243,10 +243,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio
case class UnresolvedFunction(
name: FunctionIdentifier,
- children: Seq[Expression],
- isDistinct: Boolean)
+ arguments: Seq[Expression],
+ isDistinct: Boolean,
+ filter: Option[Expression] = None)
extends Expression with Unevaluable {
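+ // Expose the filter predicate as a child so that analysis resolves and
+ // type-checks it together with the regular arguments.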
+ override def children: Seq[Expression] = arguments ++ filter.toSeq
+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -257,8 +260,8 @@ case class UnresolvedFunction(
}
object UnresolvedFunction {
- def apply(name: String, children: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
- UnresolvedFunction(FunctionIdentifier(name, None), children, isDistinct)
+ def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
+ UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 595d7db0bb49..1f85b07d8385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -71,23 +71,27 @@ object AggregateExpression {
def apply(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean): AggregateExpression = {
+ isDistinct: Boolean,
+ filter: Option[Expression] = None): AggregateExpression = {
AggregateExpression(
aggregateFunction,
mode,
isDistinct,
+ filter,
NamedExpression.newExprId)
}
}
/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
- * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
+ * (`isDistinct`) indicating whether the DISTINCT keyword is specified, and a field
+ * (`filter`) holding the FILTER predicate when a filter clause is specified.
*/
case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
isDistinct: Boolean,
+ filter: Option[Expression],
resultId: ExprId)
extends Expression
with Unevaluable {
@@ -104,6 +108,8 @@ case class AggregateExpression(
UnresolvedAttribute(aggregateFunction.toString)
}
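+ // Attributes referenced by the filter predicate; they must stay available in the
+ // aggregate's input (see `references` below).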
+ lazy val filterAttributes: AttributeSet = filter.map(_.references).getOrElse(AttributeSet.empty)
+
// We compute the same thing regardless of our final result.
override lazy val canonicalized: Expression = {
val normalizedAggFunc = mode match {
@@ -119,10 +125,12 @@ case class AggregateExpression(
normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
mode,
isDistinct,
+ filter.map(_.canonicalized),
ExprId(0))
}
- override def children: Seq[Expression] = aggregateFunction :: Nil
+ override def children: Seq[Expression] = aggregateFunction +: filter.toSeq
+
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
override def nullable: Boolean = aggregateFunction.nullable
@@ -130,7 +138,7 @@ case class AggregateExpression(
@transient
override lazy val references: AttributeSet = {
mode match {
- case Partial | Complete => aggregateFunction.references
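+ // The filter is evaluated against input rows, so its references are inputs too.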
+ case Partial | Complete => aggregateFunction.references ++ filterAttributes
case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 05fd5e35e22a..836b29d6f37c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1459,7 +1459,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
- case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
+ case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
prec + 10, scale)
@@ -1473,7 +1473,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case _ => we
}
- case ae @ AggregateExpression(af, _, _, _) => af match {
+ case ae @ AggregateExpression(af, _, _, _, _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index b9468007cac6..d14d45449013 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
*
- * For example (in scala):
+ * First example: query without filter clauses (in Scala):
* {{{
* val data = Seq(
* ("a", "ca1", "cb1", 10),
@@ -75,6 +75,49 @@ import org.apache.spark.sql.types.IntegerType
* LocalTableScan [...]
* }}}
*
+ * Second example: a query mixing distinct aggregates with a filtered non-distinct aggregate (in SQL):
+ * {{{
+ * SELECT
+ * COUNT(DISTINCT cat1) as cat1_cnt,
+ * COUNT(DISTINCT cat2) as cat2_cnt,
+ * SUM(value) FILTER (WHERE id > 1) AS total
+ * FROM
+ * data
+ * GROUP BY
+ * key
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [COUNT(DISTINCT 'cat1),
+ * COUNT(DISTINCT 'cat2),
+ * sum('value) with FILTER('id > 1)]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ * key = ['key]
+ * functions = [count(if (('gid = 1)) 'cat1 else null),
+ * count(if (('gid = 2)) 'cat2 else null),
+ * first(if (('gid = 0)) 'total else null) ignore nulls]
+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ * Aggregate(
+ * key = ['key, 'cat1, 'cat2, 'gid]
+ * functions = [sum('value) with FILTER('id > 1)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ * Expand(
+ * projections = [('key, null, null, 0, cast('value as bigint), 'id),
+ * ('key, 'cat1, null, 1, null, null),
+ * ('key, null, 'cat2, 2, null, null)]
+ * output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
+ * LocalTableScan [...]
+ * }}}
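+ *
+ * Note that the filter column (here 'id) is carried through the [[Expand]] projection of the
+ * non-distinct group (gid = 0) and nulled out for the distinct groups, so the FILTER
+ * predicate only passes for rows that belong to the non-distinct group.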
+ *
* The rule does the following things here:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
@@ -183,9 +226,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// only expand unfoldable children
val regularAggExprs = aggExpressions
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
- val regularAggChildren = regularAggExprs
+ val regularAggFunChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
- .distinct
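+ // Attributes referenced by FILTER predicates must also survive the [[Expand]]
+ // projection, so collect them together with the regular aggregate children.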
+ val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
+ val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
// Setup aggregates for 'regular' aggregate expressions.
@@ -194,7 +238,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
- val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
+ // The [[Expand]] output replaced the original attributes via expressionAttributePair,
+ // so the attributes in the FILTER expression must be rewritten to the new ones.
+ val filterOpt = e.filter.map(_.transform {
+ case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
+ })
+ val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)()
// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 98acad8a7413..b26e936c2ea2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -525,9 +525,9 @@ object NullPropagation extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsUp {
case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone))
- case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone))
- case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
ae.copy(aggregateFunction = Count(Literal(1)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 32dbd389afd9..8ac14264a929 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -368,7 +368,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
val rewrittenExpr = expr transform {
- case a @ AggregateExpression(aggFunc, _, _, resultId) =>
+ case a @ AggregateExpression(aggFunc, _, _, _, resultId) =>
aggFunc.defaultResult.getOrElse(Literal.default(NullType))
case _: AttributeReference => Literal.default(NullType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d3df7e03962e..41df5de61694 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1599,8 +1599,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case expressions =>
expressions
}
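+ // Turn the optional FILTER (WHERE ...) clause into the function's filter predicate.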
+ val filter = Option(ctx.where).map(expression(_))
val function = UnresolvedFunction(
- getFunctionIdentifier(ctx.functionName), arguments, isDistinct)
+ getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter)
// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 646abb6b8591..7023dbe2a367 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -167,12 +167,38 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"distinct function",
CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"),
- "DISTINCT specified, but hex is not an aggregate function" :: Nil)
+ "DISTINCT or FILTER specified, but hex is not an aggregate function" :: Nil)
+
+ errorTest(
+ "non aggregate function with filter predicate",
+ CatalystSqlParser.parsePlan("SELECT hex(a) FILTER (WHERE c = 1) FROM TaBlE2"),
+ "DISTINCT or FILTER specified, but hex is not an aggregate function" :: Nil)
errorTest(
"distinct window function",
- CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"),
- "DISTINCT specified, but percent_rank is not an aggregate function" :: Nil)
+ CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) OVER () FROM TaBlE"),
+ "DISTINCT or FILTER specified, but percent_rank is not an aggregate function" :: Nil)
+
+ errorTest(
+ "window function with filter predicate",
+ CatalystSqlParser.parsePlan("SELECT percent_rank(a) FILTER (WHERE c > 1) OVER () FROM TaBlE2"),
+ "DISTINCT or FILTER specified, but percent_rank is not an aggregate function" :: Nil)
+
+ errorTest(
+ "higher order function with filter predicate",
+ CatalystSqlParser.parsePlan("SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) " +
+ "FILTER (WHERE c > 1)"),
+ "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)
+
+ errorTest(
+ "DISTINCT and FILTER cannot be used in aggregate functions at the same time",
+ CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
+ "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)
+
+ errorTest(
+ "FILTER expression is non-deterministic, it cannot be used in aggregate functions",
+ CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
+ "FILTER expression is non-deterministic, it cannot be used in aggregate functions" :: Nil)
errorTest(
"nested aggregate functions",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index 9aab5b390fe1..9807b5dbe934 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -230,14 +230,14 @@ object AggregatingAccumulator {
val typedImperatives = mutable.Buffer.empty[TypedImperativeAggregate[_]]
val inputAttributeSeq: AttributeSeq = inputAttributes
val resultExpressions = functions.map(_.transform {
- case AggregateExpression(agg: DeclarativeAggregate, _, _, _) =>
+ case AggregateExpression(agg: DeclarativeAggregate, _, _, _, _) =>
aggBufferAttributes ++= agg.aggBufferAttributes
inputAggBufferAttributes ++= agg.inputAggBufferAttributes
initialValues ++= agg.initialValues
updateExpressions ++= agg.updateExpressions
mergeExpressions ++= agg.mergeExpressions
agg.evaluateExpression
- case AggregateExpression(agg: ImperativeAggregate, _, _, _) =>
+ case AggregateExpression(agg: ImperativeAggregate, _, _, _, _) =>
val imperative = BindReferences.bindReference(agg
.withNewMutableAggBufferOffset(aggBufferAttributes.size)
.withNewInputAggBufferOffset(inputAggBufferAttributes.size),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4d762c5ea9f3..e729fa278e9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -174,7 +174,7 @@ object AggUtils {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
- case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
+ case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
case agg =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index d03de1507fbb..527a9eac9948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -157,19 +157,44 @@ abstract class AggregationIterator(
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = {
val joinedRow = new JoinedRow
if (expressions.nonEmpty) {
- val mergeExpressions = functions.zip(expressions).flatMap {
- case (ae: DeclarativeAggregate, expression) =>
- expression.mode match {
- case Partial | Complete => ae.updateExpressions
- case PartialMerge | Final => ae.mergeExpressions
+ val mergeExpressions =
+ functions.zip(expressions.map(ae => (ae.mode, ae.isDistinct, ae.filter))).flatMap {
+ case (ae: DeclarativeAggregate, (mode, isDistinct, filter)) =>
+ mode match {
+ case Partial | Complete =>
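+ // Guard every buffer update with the filter: when the predicate does not
+ // hold, the update expression degenerates to the unchanged buffer attribute.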
+ if (filter.isDefined) {
+ ae.updateExpressions.zip(ae.aggBufferAttributes).map {
+ case (updateExpr, attr) => If(filter.get, updateExpr, attr)
+ }
+ } else {
+ ae.updateExpressions
+ }
+ case PartialMerge | Final => ae.mergeExpressions
+ }
+ case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ // Initialize predicates for aggregate functions if necessary
+ val predicateOptions = expressions.map {
+ case AggregateExpression(_, mode, _, Some(filter), _) =>
+ mode match {
+ case Partial | Complete =>
+ val predicate = Predicate.create(filter, inputAttributes)
+ predicate.initialize(partIndex)
+ Some(predicate)
+ case _ => None
}
- case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case _ => None
}
val updateFunctions = functions.zipWithIndex.collect {
case (ae: ImperativeAggregate, i) =>
expressions(i).mode match {
case Partial | Complete =>
- (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
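+ // For imperative aggregates, evaluate the filter per row and skip the
+ // update entirely when the predicate does not hold.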
+ if (predicateOptions(i).isDefined) {
+ (buffer: InternalRow, row: InternalRow) =>
+ if (predicateOptions(i).get.eval(row)) { ae.update(buffer, row) }
+ } else {
+ (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
+ }
case PartialMerge | Final =>
(buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index b79d3a278bb3..7f19d2754673 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -152,8 +152,10 @@ case class HashAggregateExec(
override def usedInputs: AttributeSet = inputSet
override def supportCodegen: Boolean = {
- // ImperativeAggregate is not supported right now
- !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+ // Codegen does not yet support ImperativeAggregate, nor aggregates with a filter predicate
+ // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
+ !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) ||
+ aggregateExpressions.exists(_.filter.isDefined))
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 1f325c11c9e4..a30b49f5680c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -61,9 +61,9 @@ class ObjectAggregationIterator(
// Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
val newExpressions = aggregateExpressions.map {
- case agg @ AggregateExpression(_, Partial, _, _) =>
+ case agg @ AggregateExpression(_, Partial, _, _, _) =>
agg.copy(mode = PartialMerge)
- case agg @ AggregateExpression(_, Complete, _, _) =>
+ case agg @ AggregateExpression(_, Complete, _, _, _) =>
agg.copy(mode = Final)
case other => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 6dc64657ebf1..99358fbf4e94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -249,9 +249,9 @@ class TungstenAggregationIterator(
// Basically the value of the KVIterator returned by externalSorter
// will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
val newExpressions = aggregateExpressions.map {
- case agg @ AggregateExpression(_, Partial, _, _) =>
+ case agg @ AggregateExpression(_, Partial, _, _, _) =>
agg.copy(mode = PartialMerge)
- case agg @ AggregateExpression(_, Complete, _, _) =>
+ case agg @ AggregateExpression(_, Complete, _, _, _) =>
agg.copy(mode = Final)
case other => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index e8248b702875..d5d11c45f853 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -136,7 +136,7 @@ abstract class WindowExecBase(
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
- case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
+ case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
case f: PythonUDF => collect("AGGREGATE", frame, e, f)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
new file mode 100644
index 000000000000..beb5b9e5fe51
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -0,0 +1,132 @@
+-- Test filter clause for aggregate expression.
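+-- Syntax: aggregate_function(args) FILTER (WHERE predicate), as in the ANSI SQL filter clause.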
+
+-- Test data.
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
+(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null)
+AS testData(a, b);
+
+CREATE OR REPLACE TEMPORARY VIEW EMP AS SELECT * FROM VALUES
+ (100, "emp 1", date "2005-01-01", 100.00D, 10),
+ (100, "emp 1", date "2005-01-01", 100.00D, 10),
+ (200, "emp 2", date "2003-01-01", 200.00D, 10),
+ (300, "emp 3", date "2002-01-01", 300.00D, 20),
+ (400, "emp 4", date "2005-01-01", 400.00D, 30),
+ (500, "emp 5", date "2001-01-01", 400.00D, NULL),
+ (600, "emp 6 - no dept", date "2001-01-01", 400.00D, 100),
+ (700, "emp 7", date "2010-01-01", 400.00D, 100),
+ (800, "emp 8", date "2016-01-01", 150.00D, 70)
+AS EMP(id, emp_name, hiredate, salary, dept_id);
+
+CREATE OR REPLACE TEMPORARY VIEW DEPT AS SELECT * FROM VALUES
+ (10, "dept 1", "CA"),
+ (20, "dept 2", "NY"),
+ (30, "dept 3", "TX"),
+ (40, "dept 4 - unassigned", "OR"),
+ (50, "dept 5 - unassigned", "NJ"),
+ (70, "dept 7", "FL")
+AS DEPT(dept_id, dept_name, state);
+
+-- Aggregate with filter and empty GroupBy expressions.
+SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData;
+SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData;
+SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
+SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
+SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp;
+SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp;
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
+-- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
+
+-- Aggregate with filter and non-empty GroupBy expressions.
+SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a;
+SELECT a, COUNT(b) FILTER (WHERE a != 2) FROM testData GROUP BY b;
+SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a;
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id;
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id;
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id;
+SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id;
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
+-- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
+
+-- Aggregate with filter and grouped by literals.
+SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1;
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1;
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1;
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1;
+
+-- Aggregate with filter, more than one aggregate function goes with distinct.
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id;
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id;
+-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
+
+-- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE.
+SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1;
+
+-- Aggregate with filter and grouped by literals (sort aggregate), here the input table is filtered using WHERE.
+SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1;
+
+-- Aggregate with filter and complex GroupBy expressions.
+SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b;
+SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1;
+SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1;
+
+-- Aggregate with filter, foldable input and multiple distinct groups.
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
+-- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2)
+-- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
+
+-- Aggregate with filter and grouped by an alias.
+SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k;
+
+-- Aggregate with filter containing a scalar subquery
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE id > (SELECT 200))
+FROM emp
+GROUP BY dept_id;
+
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE emp.dept_id = (SELECT dept_id FROM dept LIMIT 1))
+FROM emp
+GROUP BY dept_id;
+
+-- [SPARK-30220] Support Filter expression uses IN/EXISTS predicate sub-queries
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE EXISTS (SELECT state
+ FROM dept
+ WHERE dept.dept_id = emp.dept_id))
+FROM emp
+GROUP BY dept_id;
+
+SELECT emp.dept_id,
+ Sum(salary),
+ Sum(salary) FILTER (WHERE NOT EXISTS (SELECT state
+ FROM dept
+ WHERE dept.dept_id = emp.dept_id))
+FROM emp
+GROUP BY dept_id;
+
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE emp.dept_id IN (SELECT DISTINCT dept_id
+ FROM dept))
+FROM emp
+GROUP BY dept_id;
+SELECT emp.dept_id,
+ Sum(salary),
+ Sum(salary) FILTER (WHERE emp.dept_id NOT IN (SELECT DISTINCT dept_id
+ FROM dept))
+FROM emp
+GROUP BY dept_id;
+
+-- Aggregate with filter inside a subquery
+SELECT t1.b FROM (SELECT COUNT(b) FILTER (WHERE a >= 2) AS b FROM testData) t1;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
index 6f5e549644bb..746b67723483 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -232,16 +232,16 @@ select max(min(unique1)) from tenk1;
-- drop table bytea_test_table;
--- [SPARK-27986] Support Aggregate Expressions with filter
-- FILTER tests
--- select min(unique1) filter (where unique1 > 100) from tenk1;
+select min(unique1) filter (where unique1 > 100) from tenk1;
--- select sum(1/ten) filter (where ten > 0) from tenk1;
+select sum(1/ten) filter (where ten > 0) from tenk1;
-- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
-- group by ten;
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- select ten, sum(distinct four) filter (where four > 10) from onek a
-- group by ten
-- having exists (select 1 from onek b where sum(distinct a.four) = b.four);
@@ -254,6 +254,7 @@ select max(min(unique1)) from tenk1;
select (select count(*)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
+-- [SPARK-30219] Support Filter expression reference the outer query
-- select (select count(*) filter (where outer_c <> 0)
-- from (values (1)) t0(inner_c))
-- from (values (2),(3)) t1(outer_c); -- outer query is aggregation query
@@ -265,6 +266,7 @@ from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
-- filter (where o.unique1 < 10))
-- from tenk1 o; -- outer query is aggregation query
+-- [SPARK-30220] Support Filter expression uses IN/EXISTS predicate sub-queries
-- subquery in FILTER clause (PostgreSQL extension)
-- select sum(unique1) FILTER (WHERE
-- unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql
index 330817fb5374..c9ee83eb75eb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql
@@ -336,8 +336,8 @@ order by 2,1;
-- order by 2,1;
-- FILTER queries
--- [SPARK-27986] Support Aggregate Expressions with filter
--- select ten, sum(distinct four) filter (where string(four) ~ '123') from onek a
+-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
+-- select ten, sum(distinct four) filter (where string(four) like '123') from onek a
-- group by rollup(ten);
-- More rescan tests
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
index 8187f8a2773f..cd3b74b3aa03 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
@@ -404,7 +404,7 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1;
-- filter
--- [SPARK-28500] Adds support for `filter` clause
+-- [SPARK-30182] Support nested aggregates
-- SELECT sum(salary), row_number() OVER (ORDER BY depname), sum(
-- sum(salary) FILTER (WHERE enroll_date > '2007-01-01')
-- )
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql
index 231c5235b313..b11c8c05f310 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql
@@ -229,7 +229,6 @@ select udf(max(min(unique1))) from tenk1;
-- drop table bytea_test_table;
--- [SPARK-27986] Support Aggregate Expressions with filter
-- FILTER tests
-- select min(unique1) filter (where unique1 > 100) from tenk1;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
new file mode 100644
index 000000000000..5d266c980a49
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -0,0 +1,464 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 37
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
+(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null)
+AS testData(a, b)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW EMP AS SELECT * FROM VALUES
+ (100, "emp 1", date "2005-01-01", 100.00D, 10),
+ (100, "emp 1", date "2005-01-01", 100.00D, 10),
+ (200, "emp 2", date "2003-01-01", 200.00D, 10),
+ (300, "emp 3", date "2002-01-01", 300.00D, 20),
+ (400, "emp 4", date "2005-01-01", 400.00D, 30),
+ (500, "emp 5", date "2001-01-01", 400.00D, NULL),
+ (600, "emp 6 - no dept", date "2001-01-01", 400.00D, 100),
+ (700, "emp 7", date "2010-01-01", 400.00D, 100),
+ (800, "emp 8", date "2016-01-01", 150.00D, 70)
+AS EMP(id, emp_name, hiredate, salary, dept_id)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+CREATE OR REPLACE TEMPORARY VIEW DEPT AS SELECT * FROM VALUES
+ (10, "dept 1", "CA"),
+ (20, "dept 2", "NY"),
+ (30, "dept 3", "TX"),
+ (40, "dept 4 - unassigned", "OR"),
+ (50, "dept 5 - unassigned", "NJ"),
+ (70, "dept 7", "FL")
+AS DEPT(dept_id, dept_name, state)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 4
+SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
+-- !query 4 schema
+struct<count(a):bigint,count(b):bigint>
+-- !query 4 output
+2 4
+
+
+-- !query 5
+SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
+-- !query 5 schema
+struct<count(id):bigint>
+-- !query 5 output
+2
+
+
+-- !query 6
+SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
+-- !query 6 schema
+struct<count(id):bigint>
+-- !query 6 output
+2
+
+
+-- !query 7
+SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
+-- !query 7 schema
+struct<count(id):bigint>
+-- !query 7 output
+2
+
+
+-- !query 8
+SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
+-- !query 8 schema
+struct<count(id):bigint>
+-- !query 8 output
+2
+
+
+-- !query 9
+SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
+-- !query 9 schema
+struct<a:int,count(b):bigint>
+-- !query 9 output
+1 0
+2 2
+3 2
+NULL 0
+
+
+-- !query 10
+SELECT a, COUNT(b) FILTER (WHERE a != 2) FROM testData GROUP BY b
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 11
+SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
+-- !query 11 schema
+struct<count(a):bigint,count(b):bigint>
+-- !query 11 output
+0 0
+2 0
+2 0
+3 2
+
+
+-- !query 12
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
+-- !query 12 schema
+struct<dept_id:int,sum(salary):double>
+-- !query 12 output
+10 200.0
+100 400.0
+20 NULL
+30 400.0
+70 150.0
+NULL NULL
+
+
+-- !query 13
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
+-- !query 13 schema
+struct<dept_id:int,sum(salary):double>
+-- !query 13 output
+10 200.0
+100 400.0
+20 NULL
+30 400.0
+70 150.0
+NULL NULL
+
+
+-- !query 14
+SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
+-- !query 14 schema
+struct<dept_id:int,sum(salary):double>
+-- !query 14 output
+10 200.0
+100 400.0
+20 NULL
+30 400.0
+70 150.0
+NULL NULL
+
+
+-- !query 15
+SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
+-- !query 15 schema
+struct<dept_id:int,sum(salary):double>
+-- !query 15 output
+10 200.0
+100 400.0
+20 NULL
+30 400.0
+70 150.0
+NULL NULL
+
+
+-- !query 16
+SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
+-- !query 16 schema
+struct<foo:string,count(a):bigint>
+-- !query 16 output
+foo 6
+
+
+-- !query 17
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
+-- !query 17 schema
+struct<foo:string,sum(salary):double>
+-- !query 17 output
+foo 1350.0
+
+
+-- !query 18
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
+-- !query 18 schema
+struct<foo:string,sum(salary):double>
+-- !query 18 output
+foo 1350.0
+
+
+-- !query 19
+SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
+-- !query 19 schema
+struct<foo:string,sum(salary):double>
+-- !query 19 output
+foo 1350.0
+
+
+-- !query 20
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
+-- !query 20 schema
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+-- !query 20 output
+10 2 2 400.0 NULL
+100 2 2 800.0 800.0
+20 1 1 300.0 300.0
+30 1 1 400.0 400.0
+70 1 1 150.0 150.0
+NULL 1 1 400.0 400.0
+
+
+-- !query 21
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
+-- !query 21 schema
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+-- !query 21 output
+10 2 2 400.0 NULL
+100 2 2 800.0 800.0
+20 1 1 300.0 NULL
+30 1 1 400.0 NULL
+70 1 1 150.0 150.0
+NULL 1 1 400.0 NULL
+
+
+-- !query 22
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
+-- !query 22 schema
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+-- !query 22 output
+10 2 2 400.0 NULL
+100 2 2 NULL 800.0
+20 1 1 300.0 300.0
+30 1 1 NULL 400.0
+70 1 1 150.0 150.0
+NULL 1 1 NULL 400.0
+
+
+-- !query 23
+select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
+-- !query 23 schema
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+-- !query 23 output
+10 2 2 400.0 NULL
+100 2 2 NULL 800.0
+20 1 1 300.0 NULL
+30 1 1 NULL NULL
+70 1 1 150.0 150.0
+NULL 1 1 NULL NULL
+
+
+-- !query 24
+SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
+-- !query 24 schema
+struct<foo:string,approx_count_distinct(a):bigint>
+-- !query 24 output
+
+
+
+-- !query 25
+SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
+-- !query 25 schema
+struct<foo:string,max(named_struct(a, a)):struct<a:int>>
+-- !query 25 output
+
+
+
+-- !query 26
+SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
+-- !query 26 schema
+struct<(a + b):int,count(b):bigint>
+-- !query 26 output
+2 0
+3 1
+4 1
+5 1
+NULL 0
+
+
+-- !query 27
+SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1
+-- !query 27 schema
+struct<>
+-- !query 27 output
+org.apache.spark.sql.AnalysisException
+expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 28
+SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
+-- !query 28 schema
+struct<((a + 1) + 1):int,count(b):bigint>
+-- !query 28 output
+3 2
+4 2
+5 2
+NULL 1
+
+
+-- !query 29
+SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
+-- !query 29 schema
+struct<k:int,count(b):bigint>
+-- !query 29 output
+1 2
+2 2
+3 2
+NULL 1
+
+
+-- !query 30
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE id > (SELECT 200))
+FROM emp
+GROUP BY dept_id
+-- !query 30 schema
+struct<dept_id:int,avg(salary):double,avg(salary):double>
+-- !query 30 output
+10 133.33333333333334 NULL
+100 400.0 400.0
+20 300.0 300.0
+30 400.0 400.0
+70 150.0 150.0
+NULL 400.0 400.0
+
+
+-- !query 31
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE emp.dept_id = (SELECT dept_id FROM dept LIMIT 1))
+FROM emp
+GROUP BY dept_id
+-- !query 31 schema
+struct<dept_id:int,avg(salary):double,avg(salary):double>
+-- !query 31 output
+10 133.33333333333334 133.33333333333334
+100 400.0 NULL
+20 300.0 NULL
+30 400.0 NULL
+70 150.0 NULL
+NULL 400.0 NULL
+
+
+-- !query 32
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE EXISTS (SELECT state
+ FROM dept
+ WHERE dept.dept_id = emp.dept_id))
+FROM emp
+GROUP BY dept_id
+-- !query 32 schema
+struct<>
+-- !query 32 output
+org.apache.spark.sql.AnalysisException
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+: +- Project [state#x]
+: +- Filter (dept_id#x = outer(dept_id#x))
+: +- SubqueryAlias `dept`
+: +- Project [dept_id#x, dept_name#x, state#x]
+: +- SubqueryAlias `DEPT`
+: +- LocalRelation [dept_id#x, dept_name#x, state#x]
++- SubqueryAlias `emp`
+ +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+ +- SubqueryAlias `EMP`
+ +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+;
+
+
+-- !query 33
+SELECT emp.dept_id,
+ Sum(salary),
+ Sum(salary) FILTER (WHERE NOT EXISTS (SELECT state
+ FROM dept
+ WHERE dept.dept_id = emp.dept_id))
+FROM emp
+GROUP BY dept_id
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+: +- Project [state#x]
+: +- Filter (dept_id#x = outer(dept_id#x))
+: +- SubqueryAlias `dept`
+: +- Project [dept_id#x, dept_name#x, state#x]
+: +- SubqueryAlias `DEPT`
+: +- LocalRelation [dept_id#x, dept_name#x, state#x]
++- SubqueryAlias `emp`
+ +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+ +- SubqueryAlias `EMP`
+ +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+;
+
+
+-- !query 34
+SELECT emp.dept_id,
+ avg(salary),
+ avg(salary) FILTER (WHERE emp.dept_id IN (SELECT DISTINCT dept_id
+ FROM dept))
+FROM emp
+GROUP BY dept_id
+-- !query 34 schema
+struct<>
+-- !query 34 output
+org.apache.spark.sql.AnalysisException
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+: +- Distinct
+: +- Project [dept_id#x]
+: +- SubqueryAlias `dept`
+: +- Project [dept_id#x, dept_name#x, state#x]
+: +- SubqueryAlias `DEPT`
+: +- LocalRelation [dept_id#x, dept_name#x, state#x]
++- SubqueryAlias `emp`
+ +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+ +- SubqueryAlias `EMP`
+ +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+;
+
+
+-- !query 35
+SELECT emp.dept_id,
+ Sum(salary),
+ Sum(salary) FILTER (WHERE emp.dept_id NOT IN (SELECT DISTINCT dept_id
+ FROM dept))
+FROM emp
+GROUP BY dept_id
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+: +- Distinct
+: +- Project [dept_id#x]
+: +- SubqueryAlias `dept`
+: +- Project [dept_id#x, dept_name#x, state#x]
+: +- SubqueryAlias `DEPT`
+: +- LocalRelation [dept_id#x, dept_name#x, state#x]
++- SubqueryAlias `emp`
+ +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+ +- SubqueryAlias `EMP`
+ +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]
+;
+
+
+-- !query 36
+SELECT t1.b FROM (SELECT COUNT(b) FILTER (WHERE a >= 2) AS b FROM testData) t1
+-- !query 36 schema
+struct<b:bigint>
+-- !query 36 output
+4
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
index f102383cb4d8..9678b2e8966b 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 2
+-- Number of queries: 4
-- !query 0
@@ -12,11 +12,27 @@ It is not allowed to use an aggregate function in the argument of another aggreg
-- !query 1
+select min(unique1) filter (where unique1 > 100) from tenk1
+-- !query 1 schema
+struct<min(unique1):int>
+-- !query 1 output
+101
+
+
+-- !query 2
+select sum(1/ten) filter (where ten > 0) from tenk1
+-- !query 2 schema
+struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))):double>
+-- !query 2 output
+2828.9682539682954
+
+
+-- !query 3
select (select count(*)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c)
--- !query 1 schema
+-- !query 3 schema
struct<scalarsubquery():bigint>
--- !query 1 output
+-- !query 3 output
1
1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 37d98f7c8742..fb91a7f77d60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.HiveResult.hiveResultString
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.FunctionsCommand
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -2835,6 +2835,40 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil)
}
+ test("Support filter clause for aggregate function with hash aggregate") {
+ Seq(("COUNT(a)", 3), ("COLLECT_LIST(a)", Seq(1, 2, 3))).foreach { funcToResult =>
+ val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
+ val df = sql(query)
+ val physical = df.queryExecution.sparkPlan
+ val aggregateExpressions = physical.collectFirst {
+ case agg: HashAggregateExec => agg.aggregateExpressions
+ case agg: ObjectHashAggregateExec => agg.aggregateExpressions
+ }
+ assert(aggregateExpressions.isDefined)
+ assert(aggregateExpressions.get.size == 1)
+ aggregateExpressions.get.foreach { expr =>
+ assert(expr.filter.isDefined)
+ }
+ checkAnswer(df, Row(funcToResult._2) :: Nil)
+ }
+ }
+
+ test("Support filter clause for aggregate function uses SortAggregateExec") {
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+ val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2")
+ val physical = df.queryExecution.sparkPlan
+ val aggregateExpressions = physical.collectFirst {
+ case agg: SortAggregateExec => agg.aggregateExpressions
+ }
+ assert(aggregateExpressions.isDefined)
+ assert(aggregateExpressions.get.size == 1)
+ aggregateExpressions.get.foreach { expr =>
+ assert(expr.filter.isDefined)
+ }
+ checkAnswer(df, Row(3) :: Nil)
+ }
+ }
+
test("Non-deterministic aggregate functions should not be deduplicated") {
val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
val df = sql(query)