diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md index 4f50ba6d440c..34e8cfb02c9f 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-keywords.md @@ -117,6 +117,7 @@ Below is a list of all the keywords in Spark SQL. FALSEreservednon-reservedreserved FETCHreservednon-reservedreserved FIELDSnon-reservednon-reservednon-reserved + FILTERreservednon-reservedreserved FILEFORMATnon-reservednon-reservednon-reserved FIRSTnon-reservednon-reservednon-reserved FIRST_VALUEreservednon-reservedreserved diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index cd9748eaa6f2..8ad82b3ed99a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -766,7 +766,7 @@ primaryExpression | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' - (OVER windowSpec)? #functionCall + (FILTER '(' WHERE where=booleanExpression ')')? (OVER windowSpec)? #functionCall | identifier '->' expression #lambda | '(' identifier (',' identifier)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript @@ -931,6 +931,7 @@ qualifiedNameList functionName : qualifiedName + | FILTER | LEFT | RIGHT ; @@ -1286,6 +1287,7 @@ nonReserved | EXTRACT | FALSE | FETCH + | FILTER | FIELDS | FILEFORMAT | FIRST @@ -1549,6 +1551,7 @@ EXTRACT: 'EXTRACT'; FALSE: 'FALSE'; FETCH: 'FETCH'; FIELDS: 'FIELDS'; +FILTER: 'FILTER'; FILEFORMAT: 'FILEFORMAT'; FIRST: 'FIRST'; FIRST_VALUE: 'FIRST_VALUE'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2438ef921822..bfbe3b4a9c94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1381,8 +1381,8 @@ class Analyzer( */ def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { expr.transformUp { - case f1: UnresolvedFunction if containsStar(f1.children) => - f1.copy(children = f1.children.flatMap { + case f1: UnresolvedFunction if containsStar(f1.arguments) => + f1.copy(arguments = f1.arguments.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil }) @@ -1734,26 +1734,37 @@ class Analyzer( s"its class is ${other.getClass.getCanonicalName}, which is not a generator.") } } - case u @ UnresolvedFunction(funcId, children, isDistinct) => + case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter) => withPosition(u) { - v1SessionCatalog.lookupFunction(funcId, children) match { + v1SessionCatalog.lookupFunction(funcId, arguments) match { // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. case wf: AggregateWindowFunction => - if (isDistinct) { - failAnalysis( - s"DISTINCT specified, but ${wf.prettyName} is not an aggregate function") + if (isDistinct || filter.isDefined) { + failAnalysis("DISTINCT or FILTER specified, " + + s"but ${wf.prettyName} is not an aggregate function") } else { wf } // We get an aggregate function, we need to wrap it in an AggregateExpression. 
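              // With FILTER support, two extra constraints apply to the aggregate
              // branch below: the predicate cannot (yet) be combined with DISTINCT
              // (tracked by SPARK-30276), and it must be deterministic, since
              // re-evaluating a non-deterministic predicate could admit a
              // different set of rows.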
- case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) + case agg: AggregateFunction => + // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT + if (filter.isDefined) { + if (isDistinct) { + failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " + + "at the same time") + } else if (!filter.get.deterministic) { + failAnalysis("FILTER expression is non-deterministic, " + + "it cannot be used in aggregate functions") + } + } + AggregateExpression(agg, Complete, isDistinct, filter) // This function is not an aggregate function, just return the resolved one. case other => - if (isDistinct) { - failAnalysis( - s"DISTINCT specified, but ${other.prettyName} is not an aggregate function") + if (isDistinct || filter.isDefined) { + failAnalysis("DISTINCT or FILTER specified, " + + s"but ${other.prettyName} is not an aggregate function") } else { other } @@ -2351,7 +2362,7 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - ae @ AggregateExpression(function, _, _, _), + ae @ AggregateExpression(function, _, _, _, _), spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] @@ -2359,7 +2370,7 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) - case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + case AggregateExpression(aggFunc, _, _, _, _) if hasWindowFunction(aggFunc.children) => failAnalysis("It is not allowed to use a window function inside an aggregate " + "function. Please use the inner window function in a sub-query.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index cfb16233b394..9268c7c3c944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -155,7 +155,7 @@ trait CheckAnalysis extends PredicateHelper { case g: GroupingID => failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") - case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => + case w @ WindowExpression(AggregateExpression(_, _, true, _, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") case w @ WindowExpression(_: OffsetWindowFunction, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 1cd7f412bb67..11f94762d43e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -33,11 +33,14 @@ import org.apache.spark.sql.types.DataType case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case u @ UnresolvedFunction(fn, children, false) + case u @ UnresolvedFunction(fn, children, false, filter) if hasLambdaAndResolvedArguments(children) => withPosition(u) { catalog.lookupFunction(fn, children) match { - case func: HigherOrderFunction => func + case func: HigherOrderFunction => + 
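              // FILTER is only meaningful on an aggregate function, so reject it
              // before resolving the higher-order function: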
filter.foreach(_.failAnalysis("FILTER predicate specified, " + + s"but ${func.prettyName} is not an aggregate function")) + func case other => other.failAnalysis( "A lambda function should only be used in a higher order function. However, " + s"its class is ${other.getClass.getCanonicalName}, which is not a " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index e5a6f30c330e..608f39c2d86f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -243,10 +243,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio case class UnresolvedFunction( name: FunctionIdentifier, - children: Seq[Expression], - isDistinct: Boolean) + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression] = None) extends Expression with Unevaluable { + override def children: Seq[Expression] = arguments ++ filter.toSeq + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -257,8 +260,8 @@ case class UnresolvedFunction( } object UnresolvedFunction { - def apply(name: String, children: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { - UnresolvedFunction(FunctionIdentifier(name, None), children, isDistinct) + def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { + UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 595d7db0bb49..1f85b07d8385 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -71,23 +71,27 @@ object AggregateExpression { def apply( aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean): AggregateExpression = { + isDistinct: Boolean, + filter: Option[Expression] = None): AggregateExpression = { AggregateExpression( aggregateFunction, mode, isDistinct, + filter, NamedExpression.newExprId) } } /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field - * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. + * (`isDistinct`) indicating if DISTINCT keyword is specified for this function and + * a field (`filter`) indicating if filter clause is specified for this function. */ case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, isDistinct: Boolean, + filter: Option[Expression], resultId: ExprId) extends Expression with Unevaluable { @@ -104,6 +108,8 @@ case class AggregateExpression( UnresolvedAttribute(aggregateFunction.toString) } + lazy val filterAttributes: AttributeSet = filter.map(_.references).getOrElse(AttributeSet.empty) + // We compute the same thing regardless of our final result. 
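  // A hedged illustration (not part of this class) of what `filter` and
  // `filterAttributes` buy us: given attributes x and y,
  //   AggregateExpression(Sum(x), Complete, isDistinct = false,
  //     filter = Some(GreaterThan(y, Literal(0))))
  // now reports y among its `references` (see below), and the predicate
  // participates in the canonicalization here: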
override lazy val canonicalized: Expression = { val normalizedAggFunc = mode match { @@ -119,10 +125,12 @@ case class AggregateExpression( normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction], mode, isDistinct, + filter.map(_.canonicalized), ExprId(0)) } - override def children: Seq[Expression] = aggregateFunction :: Nil + override def children: Seq[Expression] = aggregateFunction +: filter.toSeq + override def dataType: DataType = aggregateFunction.dataType override def foldable: Boolean = false override def nullable: Boolean = aggregateFunction.nullable @@ -130,7 +138,7 @@ case class AggregateExpression( @transient override lazy val references: AttributeSet = { mode match { - case Partial | Complete => aggregateFunction.references + case Partial | Complete => aggregateFunction.references ++ filterAttributes case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 05fd5e35e22a..836b29d6f37c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1459,7 +1459,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match { + case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) @@ -1473,7 +1473,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case _ => we } - case ae @ AggregateExpression(af, _, _, _) => af match { + case ae @ AggregateExpression(af, _, _, _, _) => af match { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index b9468007cac6..d14d45449013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * - * For example (in scala): + * First example: query without filter clauses (in scala): * {{{ * val data = Seq( * ("a", "ca1", "cb1", 10), @@ -75,6 +75,49 @@ import org.apache.spark.sql.types.IntegerType * LocalTableScan [...] 
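 *   // 'gid' marks which Expand projection a row came from: 0 for the
 *   // regular (non-distinct) group, 1 and 2 for the two DISTINCT groups.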
* }}} * + * Second example: aggregate function without distinct and with filter clauses (in sql): + * {{{ + * SELECT + * COUNT(DISTINCT cat1) as cat1_cnt, + * COUNT(DISTINCT cat2) as cat2_cnt, + * SUM(value) FILTER (WHERE id > 1) AS total + * FROM + * data + * GROUP BY + * key + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value) with FILTER('id > 1)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value) with FILTER('id > 1)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint), 'id), + * ('key, 'cat1, null, 1, null, null), + * ('key, null, 'cat2, 2, null, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id]) + * LocalTableScan [...] + * }}} + * * The rule does the following things here: * 1. Expand the data. There are three aggregation groups in this query: * i. the non-distinct group; @@ -183,9 +226,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // only expand unfoldable children val regularAggExprs = aggExpressions .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) - val regularAggChildren = regularAggExprs + val regularAggFunChildren = regularAggExprs .flatMap(_.aggregateFunction.children.filter(!_.foldable)) - .distinct + val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) + val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) // Setup aggregates for 'regular' aggregate expressions. @@ -194,7 +238,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) - val operator = Alias(e.copy(aggregateFunction = af), e.sql)() + // We changed the attributes in the [[Expand]] output using expressionAttributePair. + // So we need to replace the attributes in FILTER expression with new ones. + val filterOpt = e.filter.map(_.transform { + case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) + }) + val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() // Select the result of the first aggregate in the last aggregate. 
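      // The FILTER predicate was already applied when 'operator' was computed
      // in the first aggregate, so no filter is attached to the result here.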
val result = AggregateExpression( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 98acad8a7413..b26e936c2ea2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -525,9 +525,9 @@ object NullPropagation extends Rule[LogicalPlan] { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) - case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) => Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) - case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => + case ae @ AggregateExpression(Count(exprs), _, false, _, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 32dbd389afd9..8ac14264a929 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -368,7 +368,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // in the expression with the value they would return for zero input tuples. // Also replace attribute refs (for example, for grouping columns) with NULL. val rewrittenExpr = expr transform { - case a @ AggregateExpression(aggFunc, _, _, resultId) => + case a @ AggregateExpression(aggFunc, _, _, resultId, _) => aggFunc.defaultResult.getOrElse(Literal.default(NullType)) case _: AttributeReference => Literal.default(NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d3df7e03962e..41df5de61694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1599,8 +1599,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } + val filter = Option(ctx.where).map(expression(_)) val function = UnresolvedFunction( - getFunctionIdentifier(ctx.functionName), arguments, isDistinct) + getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter) // Check if the function is evaluated in a windowed context. 
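    // Note that FILTER is consumed before OVER in the grammar, so e.g.
    // CatalystSqlParser.parseExpression("sum(x) FILTER (WHERE y > 0)") yields an
    // UnresolvedFunction whose filter is Some('y > 0); whether a FILTER/OVER
    // combination is legal is decided later, during analysis.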
ctx.windowSpec match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 646abb6b8591..7023dbe2a367 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -167,12 +167,38 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "distinct function", CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), - "DISTINCT specified, but hex is not an aggregate function" :: Nil) + "DISTINCT or FILTER specified, but hex is not an aggregate function" :: Nil) + + errorTest( + "non aggregate function with filter predicate", + CatalystSqlParser.parsePlan("SELECT hex(a) FILTER (WHERE c = 1) FROM TaBlE2"), + "DISTINCT or FILTER specified, but hex is not an aggregate function" :: Nil) errorTest( "distinct window function", - CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"), - "DISTINCT specified, but percent_rank is not an aggregate function" :: Nil) + CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) OVER () FROM TaBlE"), + "DISTINCT or FILTER specified, but percent_rank is not an aggregate function" :: Nil) + + errorTest( + "window function with filter predicate", + CatalystSqlParser.parsePlan("SELECT percent_rank(a) FILTER (WHERE c > 1) OVER () FROM TaBlE2"), + "DISTINCT or FILTER specified, but percent_rank is not an aggregate function" :: Nil) + + errorTest( + "higher order function with filter predicate", + CatalystSqlParser.parsePlan("SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) " + + "FILTER (WHERE c > 1)"), + "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) + + errorTest( + "DISTINCT and FILTER cannot be used in aggregate functions at the same time", + CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), + "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) + + errorTest( + "FILTER expression is non-deterministic, it cannot be used in aggregate functions", + CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), + "FILTER expression is non-deterministic, it cannot be used in aggregate functions" :: Nil) errorTest( "nested aggregate functions", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index 9aab5b390fe1..9807b5dbe934 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -230,14 +230,14 @@ object AggregatingAccumulator { val typedImperatives = mutable.Buffer.empty[TypedImperativeAggregate[_]] val inputAttributeSeq: AttributeSeq = inputAttributes val resultExpressions = functions.map(_.transform { - case AggregateExpression(agg: DeclarativeAggregate, _, _, _) => + case AggregateExpression(agg: DeclarativeAggregate, _, _, _, _) => aggBufferAttributes ++= agg.aggBufferAttributes inputAggBufferAttributes ++= agg.inputAggBufferAttributes initialValues ++= agg.initialValues updateExpressions ++= agg.updateExpressions mergeExpressions ++= agg.mergeExpressions agg.evaluateExpression - case AggregateExpression(agg: 
ImperativeAggregate, _, _, _) => + case AggregateExpression(agg: ImperativeAggregate, _, _, _, _) => val imperative = BindReferences.bindReference(agg .withNewMutableAggBufferOffset(aggBufferAttributes.size) .withNewInputAggBufferOffset(inputAggBufferAttributes.size), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 4d762c5ea9f3..e729fa278e9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -174,7 +174,7 @@ object AggUtils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true, _) => + case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] case agg => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index d03de1507fbb..527a9eac9948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -157,19 +157,44 @@ abstract class AggregationIterator( inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = { val joinedRow = new JoinedRow if (expressions.nonEmpty) { - val mergeExpressions = functions.zip(expressions).flatMap { - case (ae: DeclarativeAggregate, expression) => - expression.mode match { - case Partial | Complete => ae.updateExpressions - case PartialMerge | Final => ae.mergeExpressions + val mergeExpressions = + functions.zip(expressions.map(ae => (ae.mode, ae.isDistinct, ae.filter))).flatMap { + case (ae: DeclarativeAggregate, (mode, isDistinct, filter)) => + mode match { + case Partial | Complete => + if (filter.isDefined) { + ae.updateExpressions.zip(ae.aggBufferAttributes).map { + case (updateExpr, attr) => If(filter.get, updateExpr, attr) + } + } else { + ae.updateExpressions + } + case PartialMerge | Final => ae.mergeExpressions + } + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + // Initialize predicates for aggregate functions if necessary + val predicateOptions = expressions.map { + case AggregateExpression(_, mode, _, Some(filter), _) => + mode match { + case Partial | Complete => + val predicate = Predicate.create(filter, inputAttributes) + predicate.initialize(partIndex) + Some(predicate) + case _ => None } - case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case _ => None } val updateFunctions = functions.zipWithIndex.collect { case (ae: ImperativeAggregate, i) => expressions(i).mode match { case Partial | Complete => - (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) + if (predicateOptions(i).isDefined) { + (buffer: InternalRow, row: InternalRow) => + if (predicateOptions(i).get.eval(row)) { ae.update(buffer, row) } + } else { + (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) + } case PartialMerge | Final => (buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b79d3a278bb3..7f19d2754673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -152,8 +152,10 @@ case class HashAggregateExec( override def usedInputs: AttributeSet = inputSet override def supportCodegen: Boolean = { - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + // ImperativeAggregate and filter predicate are not supported right now + // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec + !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) || + aggregateExpressions.exists(_.filter.isDefined)) } override def inputRDDs(): Seq[RDD[InternalRow]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 1f325c11c9e4..a30b49f5680c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -61,9 +61,9 @@ class ObjectAggregationIterator( // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _, _) => + case agg @ AggregateExpression(_, Partial, _, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _, _) => + case agg @ AggregateExpression(_, Complete, _, _, _) => agg.copy(mode = Final) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 6dc64657ebf1..99358fbf4e94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -249,9 +249,9 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
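    // Any FILTER predicates were already applied on the update path while the
    // partial buffers were built (see AggregationIterator), so the
    // PartialMerge/Final modes used below can safely ignore them.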
val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _, _) => + case agg @ AggregateExpression(_, Partial, _, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _, _) => + case agg @ AggregateExpression(_, Complete, _, _, _) => agg.copy(mode = Final) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index e8248b702875..d5d11c45f853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -136,7 +136,7 @@ abstract class WindowExecBase( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case AggregateExpression(f, _, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) case f: PythonUDF => collect("AGGREGATE", frame, e, f) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql new file mode 100644 index 000000000000..beb5b9e5fe51 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -0,0 +1,132 @@ +-- Test filter clause for aggregate expression. + +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b); + +CREATE OR REPLACE TEMPORARY VIEW EMP AS SELECT * FROM VALUES + (100, "emp 1", date "2005-01-01", 100.00D, 10), + (100, "emp 1", date "2005-01-01", 100.00D, 10), + (200, "emp 2", date "2003-01-01", 200.00D, 10), + (300, "emp 3", date "2002-01-01", 300.00D, 20), + (400, "emp 4", date "2005-01-01", 400.00D, 30), + (500, "emp 5", date "2001-01-01", 400.00D, NULL), + (600, "emp 6 - no dept", date "2001-01-01", 400.00D, 100), + (700, "emp 7", date "2010-01-01", 400.00D, 100), + (800, "emp 8", date "2016-01-01", 150.00D, 70) +AS EMP(id, emp_name, hiredate, salary, dept_id); + +CREATE OR REPLACE TEMPORARY VIEW DEPT AS SELECT * FROM VALUES + (10, "dept 1", "CA"), + (20, "dept 2", "NY"), + (30, "dept 3", "TX"), + (40, "dept 4 - unassigned", "OR"), + (50, "dept 5 - unassigned", "NJ"), + (70, "dept 7", "FL") +AS DEPT(dept_id, dept_name, state); + +-- Aggregate with filter and empty GroupBy expressions. +SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData; +SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData; +SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; +SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; +SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp; +SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp; +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT +-- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; + +-- Aggregate with filter and non-empty GroupBy expressions. 
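-- Note: rows rejected by the FILTER predicate still form groups; only the
-- aggregate's update is skipped, so a group may report 0 or NULL below.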
+SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; +SELECT a, COUNT(b) FILTER (WHERE a != 2) FROM testData GROUP BY b; +SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a; +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id; +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT +-- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; + +-- Aggregate with filter and grouped by literals. +SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1; +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1; +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1; + +-- Aggregate with filter, more than one aggregate function goes with distinct. +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; +-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; + +-- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. 
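-- WHERE prunes input rows before any aggregation, so the queries below return
-- an empty result when no row has a = 0; by contrast, a FILTER predicate that
-- matches no rows still produces one output row per group, with a NULL
-- (or 0 for COUNT) aggregate value.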
+SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate with filter and grouped by literals (sort aggregate), here the input table is filtered using WHERE. +SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate with filter and complex GroupBy expressions. +SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b; +SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1; +SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1; + +-- Aggregate with filter, foldable input and multiple distinct groups. +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT +-- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +-- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k; + +-- Aggregate with filter contains exists subquery +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE id > (SELECT 200)) +FROM emp +GROUP BY dept_id; + +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE emp.dept_id = (SELECT dept_id FROM dept LIMIT 1)) +FROM emp +GROUP BY dept_id; + +-- [SPARK-30220] Support Filter expression uses IN/EXISTS predicate sub-queries +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE EXISTS (SELECT state + FROM dept + WHERE dept.dept_id = emp.dept_id)) +FROM emp +GROUP BY dept_id; + +SELECT emp.dept_id, + Sum(salary), + Sum(salary) FILTER (WHERE NOT EXISTS (SELECT state + FROM dept + WHERE dept.dept_id = emp.dept_id)) +FROM emp +GROUP BY dept_id; + +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE emp.dept_id IN (SELECT DISTINCT dept_id + FROM dept)) +FROM emp +GROUP BY dept_id; +SELECT emp.dept_id, + Sum(salary), + Sum(salary) FILTER (WHERE emp.dept_id NOT IN (SELECT DISTINCT dept_id + FROM dept)) +FROM emp +GROUP BY dept_id; + +-- Aggregate with filter is subquery +SELECT t1.b FROM (SELECT COUNT(b) FILTER (WHERE a >= 2) AS b FROM testData) t1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 6f5e549644bb..746b67723483 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -232,16 +232,16 @@ select max(min(unique1)) from tenk1; -- drop table bytea_test_table; --- [SPARK-27986] Support Aggregate Expressions with filter -- FILTER tests --- select min(unique1) filter (where unique1 > 100) from tenk1; +select min(unique1) filter (where unique1 > 100) from tenk1; --- select sum(1/ten) filter (where ten > 0) from tenk1; +select sum(1/ten) filter (where ten > 0) from tenk1; -- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a -- group by ten; +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT -- select ten, sum(distinct four) filter (where four > 10) from onek a -- group by ten -- having exists (select 1 from onek b where sum(distinct a.four) = b.four); @@ -254,6 +254,7 @@ select max(min(unique1)) from tenk1; select (select count(*) from (values (1)) t0(inner_c)) from (values (2),(3)) t1(outer_c); -- inner query is aggregation query +-- [SPARK-30219] Support Filter expression reference the outer query -- select (select 
count(*) filter (where outer_c <> 0) -- from (values (1)) t0(inner_c)) -- from (values (2),(3)) t1(outer_c); -- outer query is aggregation query @@ -265,6 +266,7 @@ from (values (2),(3)) t1(outer_c); -- inner query is aggregation query -- filter (where o.unique1 < 10)) -- from tenk1 o; -- outer query is aggregation query +-- [SPARK-30220] Support Filter expression uses IN/EXISTS predicate sub-queries -- subquery in FILTER clause (PostgreSQL extension) -- select sum(unique1) FILTER (WHERE -- unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql index 330817fb5374..c9ee83eb75eb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql @@ -336,8 +336,8 @@ order by 2,1; -- order by 2,1; -- FILTER queries --- [SPARK-27986] Support Aggregate Expressions with filter --- select ten, sum(distinct four) filter (where string(four) ~ '123') from onek a +-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT +-- select ten, sum(distinct four) filter (where string(four) like '123') from onek a -- group by rollup(ten); -- More rescan tests diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql index 8187f8a2773f..cd3b74b3aa03 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql @@ -404,7 +404,7 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1; -- filter --- [SPARK-28500] Adds support for `filter` clause +-- [SPARK-30182] Support nested aggregates -- SELECT sum(salary), row_number() OVER (ORDER BY depname), sum( -- sum(salary) FILTER (WHERE enroll_date > '2007-01-01') -- ) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql index 231c5235b313..b11c8c05f310 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part3.sql @@ -229,7 +229,6 @@ select udf(max(min(unique1))) from tenk1; -- drop table bytea_test_table; --- [SPARK-27986] Support Aggregate Expressions with filter -- FILTER tests -- select min(unique1) filter (where unique1 > 100) from tenk1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out new file mode 100644 index 000000000000..5d266c980a49 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -0,0 +1,464 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 37 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW EMP AS SELECT * FROM VALUES + (100, "emp 1", date "2005-01-01", 100.00D, 10), + (100, "emp 1", date "2005-01-01", 100.00D, 10), + (200, "emp 2", date "2003-01-01", 200.00D, 10), + (300, "emp 3", date 
"2002-01-01", 300.00D, 20), + (400, "emp 4", date "2005-01-01", 400.00D, 30), + (500, "emp 5", date "2001-01-01", 400.00D, NULL), + (600, "emp 6 - no dept", date "2001-01-01", 400.00D, 100), + (700, "emp 7", date "2010-01-01", 400.00D, 100), + (800, "emp 8", date "2016-01-01", 150.00D, 70) +AS EMP(id, emp_name, hiredate, salary, dept_id) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE OR REPLACE TEMPORARY VIEW DEPT AS SELECT * FROM VALUES + (10, "dept 1", "CA"), + (20, "dept 2", "NY"), + (30, "dept 3", "TX"), + (40, "dept 4 - unassigned", "OR"), + (50, "dept 5 - unassigned", "NJ"), + (70, "dept 7", "FL") +AS DEPT(dept_id, dept_name, state) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 4 +SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData +-- !query 4 schema +struct +-- !query 4 output +2 4 + + +-- !query 5 +SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp +-- !query 5 schema +struct +-- !query 5 output +2 + + +-- !query 6 +SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp +-- !query 6 schema +struct +-- !query 6 output +2 + + +-- !query 7 +SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp +-- !query 7 schema +struct +-- !query 7 output +2 + + +-- !query 8 +SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp +-- !query 8 schema +struct +-- !query 8 output +2 + + +-- !query 9 +SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a +-- !query 9 schema +struct +-- !query 9 output +1 0 +2 2 +3 2 +NULL 0 + + +-- !query 10 +SELECT a, COUNT(b) FILTER (WHERE a != 2) FROM testData GROUP BY b +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 11 +SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a +-- !query 11 schema +struct +-- !query 11 output +0 0 +2 0 +2 0 +3 2 + + +-- !query 12 +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id +-- !query 12 schema +struct +-- !query 12 output +10 200.0 +100 400.0 +20 NULL +30 400.0 +70 150.0 +NULL NULL + + +-- !query 13 +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id +-- !query 13 schema +struct +-- !query 13 output +10 200.0 +100 400.0 +20 NULL +30 400.0 +70 150.0 +NULL NULL + + +-- !query 14 +SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id +-- !query 14 schema +struct +-- !query 14 output +10 200.0 +100 400.0 +20 NULL +30 400.0 +70 150.0 +NULL NULL + + +-- !query 15 +SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id +-- !query 15 schema +struct +-- !query 15 output +10 200.0 +100 400.0 +20 NULL +30 400.0 +70 150.0 +NULL NULL + + +-- !query 16 +SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 +-- !query 16 schema +struct +-- !query 16 output +foo 6 + + +-- !query 17 +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1 +-- !query 17 schema +struct +-- !query 17 output +foo 1350.0 + + +-- !query 18 +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1 +-- !query 18 schema +struct +-- !query 18 output +foo 1350.0 + + +-- !query 19 +SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1 +-- !query 19 schema +struct +-- !query 19 output +foo 1350.0 + + +-- !query 20 +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query 20 schema +struct +-- !query 20 output +10 2 2 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 300.0 +30 1 1 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 1 400.0 400.0 + + +-- !query 21 +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id +-- !query 21 schema +struct +-- !query 21 output +10 2 2 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 NULL +30 1 1 400.0 NULL +70 1 1 150.0 150.0 +NULL 1 1 400.0 NULL + + +-- !query 22 +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query 22 schema +struct +-- !query 22 output +10 2 2 400.0 NULL +100 2 2 NULL 800.0 +20 1 1 300.0 300.0 +30 1 1 NULL 400.0 +70 1 1 150.0 150.0 +NULL 1 1 NULL 400.0 + + +-- !query 23 +select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id +-- !query 23 schema +struct +-- !query 23 output +10 2 2 400.0 NULL +100 2 2 NULL 800.0 +20 1 1 300.0 NULL +30 1 1 NULL NULL +70 1 1 150.0 150.0 +NULL 1 1 NULL NULL + + +-- !query 24 +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 24 schema +struct +-- !query 24 output + + + +-- !query 25 +SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) 
FROM testData WHERE a = 0 GROUP BY 1 +-- !query 25 schema +struct> +-- !query 25 output + + + +-- !query 26 +SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b +-- !query 26 schema +struct<(a + b):int,count(b):bigint> +-- !query 26 output +2 0 +3 1 +4 1 +5 1 +NULL 0 + + +-- !query 27 +SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1 +-- !query 27 schema +struct<> +-- !query 27 output +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 28 +SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1 +-- !query 28 schema +struct<((a + 1) + 1):int,count(b):bigint> +-- !query 28 output +3 2 +4 2 +5 2 +NULL 1 + + +-- !query 29 +SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k +-- !query 29 schema +struct +-- !query 29 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 30 +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE id > (SELECT 200)) +FROM emp +GROUP BY dept_id +-- !query 30 schema +struct +-- !query 30 output +10 133.33333333333334 NULL +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL 400.0 400.0 + + +-- !query 31 +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE emp.dept_id = (SELECT dept_id FROM dept LIMIT 1)) +FROM emp +GROUP BY dept_id +-- !query 31 schema +struct +-- !query 31 output +10 133.33333333333334 133.33333333333334 +100 400.0 NULL +20 300.0 NULL +30 400.0 NULL +70 150.0 NULL +NULL 400.0 NULL + + +-- !query 32 +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE EXISTS (SELECT state + FROM dept + WHERE dept.dept_id = emp.dept_id)) +FROM emp +GROUP BY dept_id +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x] +: +- Project [state#x] +: +- Filter (dept_id#x = outer(dept_id#x)) +: +- SubqueryAlias `dept` +: +- Project [dept_id#x, dept_name#x, state#x] +: +- SubqueryAlias `DEPT` +: +- LocalRelation [dept_id#x, dept_name#x, state#x] ++- SubqueryAlias `emp` + +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + +- SubqueryAlias `EMP` + +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] +; + + +-- !query 33 +SELECT emp.dept_id, + Sum(salary), + Sum(salary) FILTER (WHERE NOT EXISTS (SELECT state + FROM dept + WHERE dept.dept_id = emp.dept_id)) +FROM emp +GROUP BY dept_id +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x] +: +- Project [state#x] +: +- Filter (dept_id#x = outer(dept_id#x)) +: +- SubqueryAlias `dept` +: +- Project [dept_id#x, dept_name#x, state#x] +: +- SubqueryAlias `DEPT` +: +- LocalRelation [dept_id#x, dept_name#x, state#x] ++- SubqueryAlias `emp` + +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + +- SubqueryAlias `EMP` + +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] +; + + +-- !query 34 +SELECT emp.dept_id, + avg(salary), + avg(salary) FILTER (WHERE emp.dept_id IN (SELECT DISTINCT dept_id + FROM dept)) +FROM emp +GROUP BY 
dept_id +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x] +: +- Distinct +: +- Project [dept_id#x] +: +- SubqueryAlias `dept` +: +- Project [dept_id#x, dept_name#x, state#x] +: +- SubqueryAlias `DEPT` +: +- LocalRelation [dept_id#x, dept_name#x, state#x] ++- SubqueryAlias `emp` + +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + +- SubqueryAlias `EMP` + +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] +; + + +-- !query 35 +SELECT emp.dept_id, + Sum(salary), + Sum(salary) FILTER (WHERE emp.dept_id NOT IN (SELECT DISTINCT dept_id + FROM dept)) +FROM emp +GROUP BY dept_id +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x] +: +- Distinct +: +- Project [dept_id#x] +: +- SubqueryAlias `dept` +: +- Project [dept_id#x, dept_name#x, state#x] +: +- SubqueryAlias `DEPT` +: +- LocalRelation [dept_id#x, dept_name#x, state#x] ++- SubqueryAlias `emp` + +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + +- SubqueryAlias `EMP` + +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] +; + + +-- !query 36 +SELECT t1.b FROM (SELECT COUNT(b) FILTER (WHERE a >= 2) AS b FROM testData) t1 +-- !query 36 schema +struct +-- !query 36 output +4 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index f102383cb4d8..9678b2e8966b 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 2 +-- Number of queries: 4 -- !query 0 @@ -12,11 +12,27 @@ It is not allowed to use an aggregate function in the argument of another aggreg -- !query 1 +select min(unique1) filter (where unique1 > 100) from tenk1 +-- !query 1 schema +struct +-- !query 1 output +101 + + +-- !query 2 +select sum(1/ten) filter (where ten > 0) from tenk1 +-- !query 2 schema +struct +-- !query 2 output +2828.9682539682954 + + +-- !query 3 select (select count(*) from (values (1)) t0(inner_c)) from (values (2),(3)) t1(outer_c) --- !query 1 schema +-- !query 3 schema struct --- !query 1 output +-- !query 3 output 1 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 37d98f7c8742..fb91a7f77d60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.HiveResult.hiveResultString -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import 
org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.FunctionsCommand import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -2835,6 +2835,40 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) } + test("Support filter clause for aggregate function with hash aggregate") { + Seq(("COUNT(a)", 3), ("COLLECT_LIST(a)", Seq(1, 2, 3))).foreach { funcToResult => + val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg: HashAggregateExec => agg.aggregateExpressions + case agg: ObjectHashAggregateExec => agg.aggregateExpressions + } + assert(aggregateExpressions.isDefined) + assert(aggregateExpressions.get.size == 1) + aggregateExpressions.get.foreach { expr => + assert(expr.filter.isDefined) + } + checkAnswer(df, Row(funcToResult._2) :: Nil) + } + } + + test("Support filter clause for aggregate function uses SortAggregateExec") { + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2") + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg: SortAggregateExec => agg.aggregateExpressions + } + assert(aggregateExpressions.isDefined) + assert(aggregateExpressions.get.size == 1) + aggregateExpressions.get.foreach { expr => + assert(expr.filter.isDefined) + } + checkAnswer(df, Row(3) :: Nil) + } + } + test("Non-deterministic aggregate functions should not be deduplicated") { val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" val df = sql(query)