diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 18f814d6cdfd..089ab2b29b5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -194,8 +194,8 @@ class Analyzer(
       exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
-        Aggregate(groups, assignAliases(aggs), child)
+      case Aggregate(groups, aggs, child, grouped) if child.resolved && hasUnresolvedAlias(aggs) =>
+        Aggregate(groups, assignAliases(aggs), child, grouped)
 
       case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
         g.copy(aggregations = assignAliases(g.aggregations))
@@ -281,9 +281,9 @@ class Analyzer(
         failAnalysis(
           s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
 
-      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child, _) =>
         GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
-      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child, _) =>
         GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
 
       // Ensure all the expressions have been resolved.
@@ -496,7 +496,7 @@ class Analyzer(
             if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
 
-        case oldVersion @ Aggregate(_, aggregateExpressions, _)
+        case oldVersion @ Aggregate(_, aggregateExpressions, _, _)
            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
@@ -728,7 +728,7 @@ class Analyzer(
       // Replace the index with the corresponding expression in aggregateExpressions. The index is
       // a 1-base position of aggregateExpressions, which is output columns (select expression)
-      case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+      case a @ Aggregate(groups, aggs, child, isGrouped) if aggs.forall(_.resolved) &&
         groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
         val newGroups = groups.map {
           case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
@@ -745,7 +745,7 @@ class Analyzer(
               s"(valid range is [1, ${aggs.size}])")
           case o => o
         }
-        Aggregate(newGroups, aggs, child)
+        Aggregate(newGroups, aggs, child, isGrouped)
     }
   }
@@ -991,11 +991,12 @@ class Analyzer(
         } else {
           p
         }
-      case a @ Aggregate(grouping, expressions, child) =>
+      case a @ Aggregate(grouping, expressions, child, _) =>
         failOnOuterReference(a)
         val referencesToAdd = missingReferences(a)
         if (referencesToAdd.nonEmpty) {
-          Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
+          val newGrouping = grouping ++ referencesToAdd
+          Aggregate(newGrouping, expressions ++ referencesToAdd, child)
         } else {
           a
         }
@@ -1189,7 +1190,7 @@ class Analyzer(
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case filter @ Filter(havingCondition,
-             aggregate @ Aggregate(grouping, originalAggExprs, child))
+             aggregate @ Aggregate(grouping, originalAggExprs, child, isGrouped))
           if aggregate.resolved =>
 
         // Try resolving the condition of the filter as though it is in the aggregate clause
@@ -1198,7 +1199,8 @@ class Analyzer(
           Aggregate(
             grouping,
             Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
-            child)
+            child,
+            isGrouped)
         val resolvedOperator = execute(aggregatedCondition)
         def resolvedAggregateFilter =
           resolvedOperator
@@ -1684,13 +1686,13 @@ class Analyzer(
 
       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child, isGrouped))
         if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
         val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
         // Create an Aggregate operator to evaluate aggregation functions.
-        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, isGrouped)
         // Add a Filter operator for conditions in the Having clause.
         val withFilter = Filter(condition, withAggregate)
         val withWindow = addWindow(windowExpressions, withFilter)
@@ -1702,12 +1704,12 @@ class Analyzer(
       case p: LogicalPlan if !p.childrenResolved => p
 
       // Aggregate without Having clause.
-      case a @ Aggregate(groupingExprs, aggregateExprs, child)
+      case a @ Aggregate(groupingExprs, aggregateExprs, child, isGrouped)
         if hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
         val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
         // Create an Aggregate operator to evaluate aggregation functions.
-        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, isGrouped)
         // Add Window operators.
         val withWindow = addWindow(windowExpressions, withAggregate)
@@ -2100,9 +2102,9 @@ object CleanupAliases extends Rule[LogicalPlan] {
         projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
       Project(cleanedProjectList, child)
 
-    case Aggregate(grouping, aggs, child) =>
+    case Aggregate(grouping, aggs, child, isGrouped) =>
       val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
-      Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+      Aggregate(grouping.map(trimAliases), cleanedAggs, child, isGrouped)
 
     case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
       val cleanedWindowExprs =
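The `Analyzer` changes above are mechanical: every rule that destructured the three-field `Aggregate` now matches four fields and threads the flag through unchanged. The flag encodes a distinction the old node could not represent once the optimizer rewrites grouping keys: a global aggregate returns one row on empty input, while a grouped aggregate returns none, even when its only grouping key is a constant. A minimal sketch of the two behaviors, assuming a local `SparkSession` with this patch applied (the table name `t` is illustrative):

```scala
import org.apache.spark.sql.SparkSession

object GroupedVsGlobal {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    spark.range(0).createOrReplaceTempView("t") // an empty table

    // Global aggregate: no GROUP BY, so an empty input still yields one row.
    spark.sql("SELECT count(*) FROM t").show() // prints a single row: 0

    // Grouped aggregate: GROUP BY over empty input yields no rows at all,
    // even though the only grouping key is the constant 'foo'.
    spark.sql("SELECT 'foo', count(*) FROM t GROUP BY 1").show() // prints no rows

    spark.stop()
  }
}
```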
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index e07e9194bee9..03ac5f19672d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -195,7 +195,7 @@ trait CheckAnalysis extends PredicateHelper {
 
             checkValidJoinConditionExprs(condition)
 
-          case Aggregate(groupingExprs, aggregateExprs, child) =>
+          case Aggregate(groupingExprs, aggregateExprs, child, _) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case aggExpr: AggregateExpression =>
                 aggExpr.aggregateFunction.children.foreach { child =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index e81370c504ab..f65053f2334c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -44,7 +44,7 @@ object UnsupportedOperationChecker {
     }
 
     // Disallow multiple streaming aggregations
-    val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+    val aggregates = plan.collect { case a@Aggregate(_, _, _, _) if a.isStreaming => a }
 
     if (aggregates.size > 1) {
       throwError(
@@ -73,7 +73,7 @@ object UnsupportedOperationChecker {
      * data.
      */
     def containsCompleteData(subplan: LogicalPlan): Boolean = {
-      val aggs = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+      val aggs = plan.collect { case a @ Aggregate(_, _, _, _) if a.isStreaming => a }
       // Either the subplan has no streaming source, or it has aggregation with Complete mode
       !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d2f0c9798921..7d8a5e613919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -369,7 +369,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       d.copy(child = prunedChild(child, d.references))
 
     // Prunes the unused columns from child of Aggregate/Expand/Generate
-    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
+    case a @ Aggregate(_, _, child, _) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
@@ -1098,7 +1098,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
  */
 object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(grouping, _, _) =>
+    case a @ Aggregate(grouping, _, _, _) =>
       val newGrouping = grouping.filter(!_.foldable)
       a.copy(groupingExpressions = newGrouping)
   }
@@ -1110,7 +1110,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
  */
 object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(grouping, _, _) =>
+    case a @ Aggregate(grouping, _, _, _) =>
       val newGrouping = ExpressionSet(grouping).toSeq
       a.copy(groupingExpressions = newGrouping)
   }
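The two `Optimizer` rules above are where deriving "grouped-ness" from the grouping list used to break: `RemoveLiteralFromGroupExpressions` drops foldable keys, so a query like `select 'foo' from t where false group by 1` reaches the planner with `groupingExpressions == Nil` and was planned as a global aggregate, producing a spurious row. Because both rules rewrite with `a.copy(...)`, the new `isGrouped` field survives the pruning. A standalone model of that failure mode (plain Scala, no Spark dependencies; all names here are illustrative):

```scala
// Toy model: deriving "grouped-ness" from the grouping-key list breaks once
// the optimizer prunes foldable keys; a stored flag survives the rewrite.
final case class Agg(groupingKeys: Seq[String], isGrouped: Boolean)

object ToyOptimizer {
  // Mirrors RemoveLiteralFromGroupExpressions: literal keys are useless.
  def removeLiterals(a: Agg): Agg =
    a.copy(groupingKeys = a.groupingKeys.filterNot(_.startsWith("lit:")))

  def main(args: Array[String]): Unit = {
    val grouped = Agg(Seq("lit:foo"), isGrouped = true) // GROUP BY 'foo'
    val pruned = removeLiterals(grouped)

    // The key list is now empty, but the flag still records that the query
    // had a GROUP BY: empty input must produce zero rows, not one.
    assert(pruned.groupingKeys.isEmpty && pruned.isGrouped)
  }
}
```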
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 7400a01918c5..7e111c4841b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -69,7 +69,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
       case _: Repartition => empty(p)
       case _: RepartitionByExpression => empty(p)
       // AggregateExpressions like COUNT(*) return their results like 0.
-      case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p)
+      case Aggregate(_, ae, _, _) if !ae.exists(containsAggregateExpression) => empty(p)
       // Generators like Hive-style UDTF may return their records within `close`.
       case Generate(_: Explode, _, _, _, _, _) => empty(p)
       case _ => p
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index f14aaab72a98..49f4166c11e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -181,7 +181,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
       }
 
-      case Aggregate(_, aggExprs, _) =>
+      case Aggregate(_, aggExprs, _, _) =>
         // Some of the expressions under the Aggregate node are the join columns
         // for joining with the outer query block. Fill those expressions in with
         // nulls and statically evaluate the remainder.
@@ -322,7 +322,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * subqueries.
    */
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(grouping, expressions, child) =>
+    case a @ Aggregate(grouping, expressions, child, isGrouped) =>
       val subqueries = ArrayBuffer.empty[ScalarSubquery]
       val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
@@ -332,7 +332,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         val newGrouping = grouping.map { e =>
           subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
         }
-        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries), isGrouped)
       } else {
         a
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 41cabb8cb339..d17ab77993d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -228,10 +228,10 @@ object Unions {
 object PhysicalAggregation {
-  // groupingExpressions, aggregateExpressions, resultExpressions, child
+  // groupingExpressions, aggregateExpressions, resultExpressions, child, isGrouped
   type ReturnType =
-    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan, Boolean)
 
   def unapply(a: Any): Option[ReturnType] = a match {
-    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+    case logical.Aggregate(groupingExpressions, resultExpressions, child, isGrouped) =>
       // A single aggregate expression might appear multiple times in resultExpressions.
       // In order to avoid evaluating an individual aggregate function multiple times, we'll
       // build a set of the distinct aggregate expressions and build a function which can
@@ -281,7 +281,8 @@ object PhysicalAggregation {
         namedGroupingExpressions.map(_._2),
         aggregateExpressions,
         rewrittenResultExpressions,
-        child))
+        child,
+        isGrouped))
 
     case _ => None
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d2d33e40a8c8..a114ffa64542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -472,10 +472,20 @@ case class Range(
   }
 }
 
+object Aggregate {
+  def apply(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[NamedExpression],
+      child: LogicalPlan): Aggregate = {
+    Aggregate(groupingExpressions, aggregateExpressions, child, groupingExpressions.nonEmpty)
+  }
+}
+
 case class Aggregate(
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: LogicalPlan)
+    child: LogicalPlan,
+    isGrouped: Boolean)
   extends UnaryNode {
 
   override lazy val resolved: Boolean = {
@@ -484,7 +494,10 @@ case class Aggregate(
       }.nonEmpty
     )
 
-    !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
+    !expressions.exists(!_.resolved) &&
+      childrenResolved &&
+      !hasWindowExpressions &&
+      (isGrouped || groupingExpressions.isEmpty)
   }
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
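The companion `apply` above keeps existing three-argument call sites source-compatible: freshly constructed nodes derive the flag from `groupingExpressions.nonEmpty`, while `copy`, the path the optimizer rules take, carries the original flag through rewrites. The extra conjunct in `resolved` rejects the one inconsistent state (`isGrouped = false` with non-empty grouping keys). A sketch of the construction-versus-rewrite asymmetry, using a standalone mirror of the pattern (names are illustrative):

```scala
// Standalone mirror of the Aggregate companion-apply pattern.
final case class Node(keys: Seq[String], isGrouped: Boolean)

object Node {
  // Like the new Aggregate.apply: fresh nodes derive the flag from the keys.
  def apply(keys: Seq[String]): Node = Node(keys, keys.nonEmpty)
}

object Demo {
  def main(args: Array[String]): Unit = {
    val grouped = Node(Seq("a")) // isGrouped = true, derived
    val global = Node(Nil)       // isGrouped = false, derived

    // copy() preserves the flag even when a rewrite empties the key list,
    // which is exactly what the optimizer rules rely on after pruning.
    val pruned = grouped.copy(keys = Nil)
    assert(pruned.isGrouped && global != pruned)
  }
}
```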
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 6f821f80cc4c..534c7cd37eaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -126,7 +126,7 @@ class SQLBuilder private (
 
       case p: Project => projectToSQL(p, isDistinct = false)
 
-      case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+      case a @ Aggregate(_, _, e @ Expand(_, _, p: Project), _) if isGroupingSet(a, e, p) =>
         groupingSetToSQL(a, e, p)
 
       case p: Aggregate =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 1b7fedca8484..7816e9e29035 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -47,7 +47,7 @@ case class OptimizeMetadataOnlyQuery(
     }
 
     plan.transform {
-      case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) =>
+      case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation), _) =>
         // We only apply this optimization when only partitioned attributes are scanned.
         if (a.references.subsetOf(partAttrs)) {
           val aggFunctions = aggExprs.flatMap(_.collect {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3441ccf53b45..30c507383cd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, SaveMode, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
@@ -228,13 +227,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object StatefulAggregationStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalAggregation(
-        namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+        namedGroupingExpressions,
+        aggregateExpressions,
+        rewrittenResultExpressions,
+        child,
+        isGrouped) =>
 
         aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
           aggregateExpressions,
           rewrittenResultExpressions,
-          planLater(child))
+          planLater(child),
+          isGrouped)
 
       case _ => Nil
     }
@@ -246,7 +250,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalAggregation(
-          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
+          groupingExpressions, aggregateExpressions, resultExpressions, child, isGrouped) =>
 
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
@@ -267,21 +271,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
-              planLater(child))
+              planLater(child),
+              isGrouped)
           }
         } else if (functionsWithDistinct.isEmpty) {
           aggregate.AggUtils.planAggregateWithoutDistinct(
             groupingExpressions,
             aggregateExpressions,
             resultExpressions,
-            planLater(child))
+            planLater(child),
+            isGrouped)
         } else {
           aggregate.AggUtils.planAggregateWithOneDistinct(
             groupingExpressions,
             functionsWithDistinct,
             functionsWithoutDistinct,
             resultExpressions,
-            planLater(child))
+            planLater(child),
+            isGrouped)
         }
 
         aggregateOperator
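Because `PhysicalAggregation.ReturnType` grows a fifth element, every match on the extractor, including any in third-party planner extensions, needs the same one-line arity fix shown in the two strategies above. A hypothetical external strategy after the update (the object name is made up; the imports are the real classes this patch touches):

```scala
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical third-party strategy: the only migration cost is the fifth
// binder in the pattern (use `_` if empty-input semantics are irrelevant).
object MyAggStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalAggregation(grouping, aggregates, results, child, isGrouped) =>
      // Inspect only; fall through to the built-in aggregation strategies.
      Nil
    case _ => Nil
  }
}
```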
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c9b..0ec6bf8d7fb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -31,7 +31,8 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan,
+      isGrouped: Boolean): Seq[SparkPlan] = {
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
     val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
 
@@ -42,7 +43,8 @@ object AggUtils {
       aggregateAttributes = completeAggregateAttributes,
       initialInputBufferOffset = 0,
       resultExpressions = resultExpressions,
-      child = child
+      child = child,
+      isGrouped = isGrouped
     ) :: Nil
   }
 
@@ -53,7 +55,8 @@ object AggUtils {
       aggregateAttributes: Seq[Attribute] = Nil,
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
-      child: SparkPlan): SparkPlan = {
+      child: SparkPlan,
+      isGrouped: Boolean): SparkPlan = {
     val useHash = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
@@ -64,7 +67,8 @@ object AggUtils {
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = initialInputBufferOffset,
         resultExpressions = resultExpressions,
-        child = child)
+        child = child,
+        isGrouped = isGrouped)
     } else {
       SortAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
@@ -73,7 +77,8 @@ object AggUtils {
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = initialInputBufferOffset,
         resultExpressions = resultExpressions,
-        child = child)
+        child = child,
+        isGrouped = isGrouped)
     }
   }
 
@@ -81,7 +86,8 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan,
+      isGrouped: Boolean): Seq[SparkPlan] = {
     // Check if we can use HashAggregate.
 
     // 1. Create an Aggregate Operator for partial aggregations.
@@ -101,7 +107,8 @@ object AggUtils {
         aggregateAttributes = partialAggregateAttributes,
         initialInputBufferOffset = 0,
         resultExpressions = partialResultExpressions,
-        child = child)
+        child = child,
+        isGrouped = isGrouped)
 
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -116,7 +123,8 @@ object AggUtils {
       aggregateAttributes = finalAggregateAttributes,
       initialInputBufferOffset = groupingExpressions.length,
       resultExpressions = resultExpressions,
-      child = partialAggregate)
+      child = partialAggregate,
+      isGrouped = isGrouped)
 
     finalAggregate :: Nil
   }
 
@@ -126,7 +134,8 @@ object AggUtils {
       functionsWithDistinct: Seq[AggregateExpression],
      functionsWithoutDistinct: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan,
+      isGrouped: Boolean): Seq[SparkPlan] = {
 
     // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
     // DISTINCT aggregate function, all of those functions will have the same column expressions.
@@ -154,7 +163,8 @@ object AggUtils {
         aggregateAttributes = aggregateAttributes,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
+        child = child,
+        isGrouped = true)
     }
 
     // 2. Create an Aggregate Operator for partial merge aggregations.
@@ -170,7 +180,8 @@ object AggUtils {
         initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = partialAggregate)
+        child = partialAggregate,
+        isGrouped = true)
     }
 
     // 3. Create an Aggregate operator for partial aggregation (for distinct)
@@ -211,7 +222,8 @@ object AggUtils {
         aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
         initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
         resultExpressions = partialAggregateResult,
-        child = partialMergeAggregate)
+        child = partialMergeAggregate,
+        isGrouped = isGrouped)
     }
 
     // 4. Create an Aggregate Operator for the final aggregation.
@@ -241,7 +253,8 @@ object AggUtils {
         aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = resultExpressions,
-        child = partialDistinctAggregate)
+        child = partialDistinctAggregate,
+        isGrouped = isGrouped)
     }
 
     finalAndCompleteAggregate :: Nil
 
@@ -261,7 +274,8 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan,
+      isGrouped: Boolean): Seq[SparkPlan] = {
 
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
@@ -277,7 +291,8 @@ object AggUtils {
         aggregateAttributes = aggregateAttributes,
         resultExpressions = groupingAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
+        child = child,
+        isGrouped = isGrouped)
     }
 
     val partialMerged1: SparkPlan = {
@@ -292,7 +307,8 @@ object AggUtils {
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = groupingAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = partialAggregate)
+        child = partialAggregate,
+        isGrouped = isGrouped)
     }
 
     val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
 
@@ -309,7 +325,8 @@ object AggUtils {
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = groupingAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = restored)
+        child = restored,
+        isGrouped = isGrouped)
     }
     // Note: stateId and returnAllStates are filled in later with preparation rules
     // in IncrementalExecution.
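One subtlety above: steps 1 and 2 of `planAggregateWithOneDistinct` pass `isGrouped = true` unconditionally, while steps 3 and 4 thread the caller's flag. That looks deliberate: the intermediate operators group by `groupingAttributes ++ distinctAttributes`, and the distinct column set is non-empty for any DISTINCT aggregate, so those operators are genuinely grouped even when the query itself has no GROUP BY. A short sketch of the key sets involved (plain Scala; the query in the comment is illustrative):

```scala
// For: SELECT count(DISTINCT x) FROM t   -- no GROUP BY, so isGrouped = false
object DistinctPlanKeys {
  def main(args: Array[String]): Unit = {
    val groupingAttributes = Seq.empty[String] // from the (absent) GROUP BY
    val distinctAttributes = Seq("x")          // from count(DISTINCT x)

    // Steps 1-2 group by both sets, so they always have at least one key
    // and can safely hardcode isGrouped = true.
    assert((groupingAttributes ++ distinctAttributes).nonEmpty)

    // Steps 3-4 group only by groupingAttributes and must inherit the
    // query's flag: here false, so an empty input still yields one row (0).
    println(s"final-step keys: $groupingAttributes")
  }
}
```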
@@ -329,7 +346,8 @@ object AggUtils {
         aggregateAttributes = finalAggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = resultExpressions,
-        child = saved)
+        child = saved,
+        isGrouped = isGrouped)
     }
 
     finalAndCompleteAggregate :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 59e132dfb252..a6584c26a353 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -38,12 +38,14 @@ import org.apache.spark.util.Utils
 case class HashAggregateExec(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
+    isGrouped: Boolean,
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryExecNode with CodegenSupport {
+  assert(isGrouped || groupingExpressions.isEmpty)
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -95,7 +97,7 @@ case class HashAggregateExec(
     child.execute().mapPartitions { iter =>
 
       val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
+      if (!hasInput && isGrouped) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
         Iterator.empty
@@ -115,7 +117,7 @@ case class HashAggregateExec(
           numOutputRows,
           peakMemory,
           spillSize)
-        if (!hasInput && groupingExpressions.isEmpty) {
+        if (!hasInput && !isGrouped) {
           numOutputRows += 1
           Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
@@ -140,7 +142,7 @@ case class HashAggregateExec(
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
-    if (groupingExpressions.isEmpty) {
+    if (!isGrouped) {
       doProduceWithoutKeys(ctx)
     } else {
       doProduceWithKeys(ctx)
@@ -148,7 +150,7 @@ case class HashAggregateExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    if (groupingExpressions.isEmpty) {
+    if (!isGrouped) {
       doConsumeWithoutKeys(ctx, input)
     } else {
       doConsumeWithKeys(ctx, input)
@@ -484,7 +486,8 @@ case class HashAggregateExec(
     val isSupported =
       (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
         f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
-        bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
+        bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) &&
+        groupingExpressions.nonEmpty
 
     // For vectorized hash map, We do not support byte array based decimal type for aggregate values
     // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -495,7 +498,7 @@ case class HashAggregateExec(
     val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
       .forall(!DecimalType.isByteArrayDecimalType(_))
 
-    isSupported && isNotByteArrayDecimalType
+    isSupported && isNotByteArrayDecimalType
   }
 
   private def enableTwoLevelHashMap(ctx: CodegenContext) = {
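`HashAggregateExec` is where the flag pays off at runtime: the empty-input branches in `doExecute` and the codegen dispatch now consult `isGrouped` instead of the possibly pruned `groupingExpressions`. The extra `groupingExpressions.nonEmpty` conjunct in the fast-hash-map support check covers the new state that check could not previously encounter: an operator can now be grouped yet have no key columns, and a key-based hash map is useless for it. The resulting decision table, as a toy mirror of the branches above (not Spark code):

```scala
// Rows emitted for an empty partition, keyed on the flag.
object EmptyInputRows {
  def rows(hasInput: Boolean, isGrouped: Boolean): Option[Int] =
    if (hasInput) None            // normal path, not modeled here
    else if (isGrouped) Some(0)   // grouped: no groups, no rows
    else Some(1)                  // global: one row of default aggregate values

  def main(args: Array[String]): Unit = {
    assert(rows(hasInput = false, isGrouped = true).contains(0))  // GROUP BY 1, key pruned
    assert(rows(hasInput = false, isGrouped = false).contains(1)) // SELECT count(*) FROM empty
  }
}
```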
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2a81a823c44b..6723c0e66d43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -33,12 +33,14 @@ import org.apache.spark.util.Utils
 case class SortAggregateExec(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
+    isGrouped: Boolean,
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryExecNode {
+  assert(isGrouped || groupingExpressions.isEmpty)
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -76,7 +78,7 @@ case class SortAggregateExec(
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
+      if (!hasInput && isGrouped) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
         Iterator[UnsafeRow]()
@@ -92,7 +94,7 @@ case class SortAggregateExec(
           (expressions, inputSchema) =>
             newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
           numOutputRows)
-        if (!hasInput && groupingExpressions.isEmpty) {
+        if (!hasInput && !isGrouped) {
           // There is no input and there is no grouping expressions.
           // We need to output a single row as the output.
           numOutputRows += 1
@@ -114,10 +116,11 @@ case class SortAggregateExec(
     val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
     val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]")
     val outputString = Utils.truncatedString(output, "[", ", ", "]")
-    if (verbose) {
-      s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)"
+    val suffix = if (verbose) {
+      s", output=$outputString, isGrouped=$isGrouped"
     } else {
-      s"SortAggregate(key=$keyString, functions=$functionString)"
+      ""
     }
+    s"SortAggregate(key=$keyString, functions=$functionString$suffix)"
   }
 }
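The golden-file tests below exercise all three physical paths (whole-stage codegen, hash aggregate via `approx_count_distinct`, sort aggregate via `max(struct(...))`) against both empty and non-empty inputs. The same checks can be run from a spark-shell build with this patch; the view definition is taken verbatim from the test file:

```scala
import org.apache.spark.sql.SparkSession

object GroupBySemanticsCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("check").getOrCreate()
    spark.sql("create temporary view myview as values 128, 256 as v(int_col)")

    // Non-empty input: a constant grouping key collapses to a single group.
    assert(spark.sql("select 'foo', count(*) from myview group by 1").count() == 1)

    // Empty input: a grouped aggregate must produce no rows, whichever
    // physical operator (hash or sort aggregate) gets planned.
    assert(spark.sql(
      "select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1"
    ).count() == 0)

    spark.stop()
  }
}
```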
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
new file mode 100644
index 000000000000..6741703d9d82
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -0,0 +1,17 @@
+-- Temporary data.
+create temporary view myview as values 128, 256 as v(int_col);
+
+-- group by should produce all input rows.
+select int_col, count(*) from myview group by int_col;
+
+-- group by should produce a single row.
+select 'foo', count(*) from myview group by 1;
+
+-- group-by should not produce any rows (whole stage code generation).
+select 'foo' from myview where int_col == 0 group by 1;
+
+-- group-by should not produce any rows (hash aggregate).
+select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1;
+
+-- group-by should not produce any rows (sort aggregate).
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
new file mode 100644
index 000000000000..9127bd4dd4c6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -0,0 +1,51 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+create temporary view myview as values 128, 256 as v(int_col)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select int_col, count(*) from myview group by int_col
+-- !query 1 schema
+struct<int_col:int,count(1):bigint>
+-- !query 1 output
+128	1
+256	1
+
+
+-- !query 2
+select 'foo', count(*) from myview group by 1
+-- !query 2 schema
+struct<foo:string,count(1):bigint>
+-- !query 2 output
+foo	2
+
+
+-- !query 3
+select 'foo' from myview where int_col == 0 group by 1
+-- !query 3 schema
+struct<foo:string>
+-- !query 3 output
+
+
+
+-- !query 4
+select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1
+-- !query 4 schema
+struct<foo:string,approx_count_distinct(int_col):bigint>
+-- !query 4 output
+
+
+
+-- !query 5
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1
+-- !query 5 schema
+struct<foo:string,max(struct(int_col)):struct<int_col:int>>
+-- !query 5 output
+