@@ -194,8 +194,8 @@ class Analyzer(
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
-   Aggregate(groups, assignAliases(aggs), child)
+ case Aggregate(groups, aggs, child, grouped) if child.resolved && hasUnresolvedAlias(aggs) =>
+   Aggregate(groups, assignAliases(aggs), child, grouped)

case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.copy(aggregations = assignAliases(g.aggregations))
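For orientation: the new fourth field records whether the aggregate was written with a GROUP BY. A minimal sketch of the two shapes, assuming the Catalyst test DSL (`catalyst.dsl`) and the three-argument companion `apply` added later in this diff; `testRelation` is a hypothetical relation, not part of the change:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}

val testRelation = LocalRelation('a.int, 'b.int)

// GROUP BY b: the three-argument apply derives isGrouped = true from the non-empty grouping list.
val grouped = Aggregate(Seq('b.attr), Seq('b.attr, sum('a).as("total")), testRelation)

// Global aggregation (no GROUP BY): empty grouping list, so isGrouped = false.
val global = Aggregate(Nil, Seq(sum('a).as("total")), testRelation)
```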
@@ -281,9 +281,9 @@ class Analyzer(
failAnalysis(
s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

- case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+ case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child, _) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
- case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+ case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child, _) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)

// Ensure all the expressions have been resolved.
@@ -496,7 +496,7 @@ class Analyzer(
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

- case oldVersion @ Aggregate(_, aggregateExpressions, _)
+ case oldVersion @ Aggregate(_, aggregateExpressions, _, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

@@ -728,7 +728,7 @@ class Analyzer(

// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
- case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+ case a @ Aggregate(groups, aggs, child, isGrouped) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
@@ -745,7 +745,7 @@
s"(valid range is [1, ${aggs.size}])")
case o => o
}
- Aggregate(newGroups, aggs, child)
+ Aggregate(newGroups, aggs, child, isGrouped)
}
}
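A concrete illustration of the ordinal case above, reusing the hypothetical DSL names from the earlier sketch (the query is illustrative, not from this diff):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal

// Hypothetical query: SELECT b, sum(a) AS total FROM t GROUP BY 1
// Before this rule fires, the grouping list still holds the ordinal:
val beforeRule = Aggregate(
  Seq(UnresolvedOrdinal(1)), Seq('b.attr, sum('a).as("total")), testRelation, isGrouped = true)
// Afterwards, ordinal 1 has been replaced by the first select expression, and the
// matched flag is passed through unchanged:
val afterRule = Aggregate(
  Seq('b.attr), Seq('b.attr, sum('a).as("total")), testRelation, isGrouped = true)
```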

@@ -991,11 +991,12 @@ class Analyzer(
} else {
p
}
- case a @ Aggregate(grouping, expressions, child) =>
+ case a @ Aggregate(grouping, expressions, child, isGrouped) =>
failOnOuterReference(a)
val referencesToAdd = missingReferences(a)
if (referencesToAdd.nonEmpty) {
- Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
+ val newGrouping = grouping ++ referencesToAdd
+ Aggregate(newGrouping, expressions ++ referencesToAdd, child)
} else {
a
}
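Note that this branch rebuilds the node with the three-argument form, so the flag is re-derived from the widened grouping list rather than taken from the matched `isGrouped`. A small sketch of what that derivation yields, reusing the earlier hypothetical names:

```scala
// 'outerRef stands in for a missing attribute pulled in from the outer query.
val referencesToAdd = Seq('outerRef.attr)
val rebuilt = Aggregate(
  Nil ++ referencesToAdd,                      // grouping widened from empty to non-empty
  Seq(sum('a).as("total")) ++ referencesToAdd,
  testRelation)
// rebuilt.isGrouped == true: the three-argument apply derives the flag from
// groupingExpressions.nonEmpty instead of copying it from the original node.
```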
@@ -1189,7 +1190,7 @@ class Analyzer(
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case filter @ Filter(havingCondition,
- aggregate @ Aggregate(grouping, originalAggExprs, child))
+ aggregate @ Aggregate(grouping, originalAggExprs, child, isGrouped))
if aggregate.resolved =>

// Try resolving the condition of the filter as though it is in the aggregate clause
@@ -1198,7 +1199,8 @@
Aggregate(
grouping,
Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
- child)
+ child,
+ isGrouped)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
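To make the HAVING path concrete, here is a sketch of the synthetic `aggregatedCondition` built for a hypothetical query, reusing the earlier DSL names; the expression classes are standard Catalyst ones:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal}

// Hypothetical query, not part of this diff:
//   SELECT b, sum(a) AS total FROM t GROUP BY b HAVING sum(a) > 10
// The HAVING predicate is re-resolved inside a synthetic aggregate that reuses the
// grouping and the isGrouped flag of the original Aggregate:
val aggregatedCondition = Aggregate(
  Seq('b.attr),
  Alias(GreaterThan(sum('a), Literal(10)), "havingCondition")(isGenerated = true) :: Nil,
  testRelation,
  isGrouped = true)
```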
@@ -1684,13 +1686,13 @@ class Analyzer(

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
- case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+ case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child, isGrouped))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
- val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+ val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, isGrouped)
// Add a Filter operator for conditions in the Having clause.
val withFilter = Filter(condition, withAggregate)
val withWindow = addWindow(windowExpressions, withFilter)
@@ -1702,12 +1704,12 @@
case p: LogicalPlan if !p.childrenResolved => p

// Aggregate without Having clause.
- case a @ Aggregate(groupingExprs, aggregateExprs, child)
+ case a @ Aggregate(groupingExprs, aggregateExprs, child, isGrouped)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
- val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+ val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, isGrouped)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withAggregate)

@@ -2100,9 +2102,9 @@ object CleanupAliases extends Rule[LogicalPlan] {
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Project(cleanedProjectList, child)

- case Aggregate(grouping, aggs, child) =>
+ case Aggregate(grouping, aggs, child, isGrouped) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
- Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+ Aggregate(grouping.map(trimAliases), cleanedAggs, child, isGrouped)

case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs =
@@ -195,7 +195,7 @@ trait CheckAnalysis extends PredicateHelper {

checkValidJoinConditionExprs(condition)

- case Aggregate(groupingExprs, aggregateExprs, child) =>
+ case Aggregate(groupingExprs, aggregateExprs, child, _) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case aggExpr: AggregateExpression =>
aggExpr.aggregateFunction.children.foreach { child =>
@@ -44,7 +44,7 @@ object UnsupportedOperationChecker {
}

// Disallow multiple streaming aggregations
- val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+ val aggregates = plan.collect { case a@Aggregate(_, _, _, _) if a.isStreaming => a }

if (aggregates.size > 1) {
throwError(
@@ -73,7 +73,7 @@
* data.
*/
def containsCompleteData(subplan: LogicalPlan): Boolean = {
- val aggs = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+ val aggs = plan.collect { case a @ Aggregate(_, _, _, _) if a.isStreaming => a }
// Either the subplan has no streaming source, or it has aggregation with Complete mode
!subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete)
}
@@ -369,7 +369,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
d.copy(child = prunedChild(child, d.references))

// Prunes the unused columns from child of Aggregate/Expand/Generate
- case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
+ case a @ Aggregate(_, _, child, _) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = prunedChild(child, a.references))
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
e.copy(child = prunedChild(child, e.references))
@@ -1098,7 +1098,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
*/
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(grouping, _, _) =>
+ case a @ Aggregate(grouping, _, _, _) =>
val newGrouping = grouping.filter(!_.foldable)
a.copy(groupingExpressions = newGrouping)
}
@@ -1110,7 +1110,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
*/
object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(grouping, _, _) =>
+ case a @ Aggregate(grouping, _, _, _) =>
val newGrouping = ExpressionSet(grouping).toSeq
a.copy(groupingExpressions = newGrouping)
}
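Because these two rules rewrite only `groupingExpressions` through `copy()`, the `isGrouped` flag survives even when every grouping key is pruned. A sketch of that interaction, under the same assumptions as the earlier sketches:

```scala
// A grouped aggregate whose only grouping key is a foldable literal.
val groupedByLiteral = Aggregate(Seq(Literal(1)), Seq(sum('a).as("total")), testRelation)
// groupedByLiteral.isGrouped == true (derived from the non-empty grouping list)

val pruned = groupedByLiteral.copy(groupingExpressions = Nil)
// pruned.isGrouped == true: copy() only swaps the grouping list, so the flag still
// records that the query was written with a GROUP BY even though no grouping
// expressions survive optimization.
```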
@@ -69,7 +69,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _: Repartition => empty(p)
case _: RepartitionByExpression => empty(p)
// AggregateExpressions like COUNT(*) return their results like 0.
- case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p)
+ case Aggregate(_, ae, _, _) if !ae.exists(containsAggregateExpression) => empty(p)
// Generators like Hive-style UDTF may return their records within `close`.
case Generate(_: Explode, _, _, _, _, _) => empty(p)
case _ => p
@@ -181,7 +181,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
}

- case Aggregate(_, aggExprs, _) =>
+ case Aggregate(_, aggExprs, _, _) =>
// Some of the expressions under the Aggregate node are the join columns
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
@@ -322,7 +322,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(grouping, expressions, child) =>
+ case a @ Aggregate(grouping, expressions, child, isGrouped) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
@@ -332,7 +332,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
- Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+ Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries), isGrouped)
} else {
a
}
@@ -228,10 +228,10 @@ object Unions {
object PhysicalAggregation {
// groupingExpressions, aggregateExpressions, resultExpressions, child
type ReturnType =
- (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+ (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan, Boolean)

def unapply(a: Any): Option[ReturnType] = a match {
- case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ case logical.Aggregate(groupingExpressions, resultExpressions, child, isGrouped) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of the distinct aggregate expressions and build a function which can
@@ -281,7 +281,8 @@ object PhysicalAggregation {
namedGroupingExpressions.map(_._2),
aggregateExpressions,
rewrittenResultExpressions,
- child))
+ child,
+ isGrouped))

case _ => None
}
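With the extra `Boolean`, callers of the extractor destructure five elements instead of four. A minimal usage sketch (hypothetical `plan` and helper, mirroring the planner changes further down in this diff):

```scala
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

def describe(plan: LogicalPlan): String = plan match {
  case PhysicalAggregation(groupingExprs, aggExprs, resultExprs, child, isGrouped) =>
    // The flag travels with the other extracted pieces, so physical planning can tell a
    // global aggregate from a grouped one even after grouping expressions are pruned.
    if (isGrouped) "grouped aggregate" else "global aggregate"
  case _ => "not an aggregate"
}
```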
@@ -472,10 +472,20 @@ case class Range(
}
}

+ object Aggregate {
+   def apply(
+       groupingExpressions: Seq[Expression],
+       aggregateExpressions: Seq[NamedExpression],
+       child: LogicalPlan): Aggregate = {
+     Aggregate(groupingExpressions, aggregateExpressions, child, groupingExpressions.nonEmpty)
+   }
+ }
+
case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: LogicalPlan)
+ child: LogicalPlan,
+ isGrouped: Boolean)
extends UnaryNode {

override lazy val resolved: Boolean = {
@@ -484,7 +494,10 @@ case class Aggregate(
}.nonEmpty
)

- !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
+ !expressions.exists(!_.resolved) &&
+   childrenResolved &&
+   !hasWindowExpressions &&
+   (isGrouped || groupingExpressions.isEmpty)
[Inline review comment from a contributor]
Why have this condition?

[Reply from the author]
It does not matter whether a grouped Aggregate still has any grouping expressions (literal grouping expressions are eliminated during optimization). It is, however, problematic when a non-grouped Aggregate has grouping expressions: that would mean the isGrouped flag was not derived correctly.

}

override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
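Taken together, the companion `apply` and the extra `resolved` conjunct enforce one invariant: a node claiming `isGrouped = false` must not carry grouping expressions. A sketch under the same assumptions as the earlier ones:

```scala
// The derived flag: a non-empty grouping list yields isGrouped = true.
val derived = Aggregate(Seq('b.attr), Seq('b.attr, sum('a).as("total")), testRelation)
// derived.isGrouped == true

// The four-argument constructor still allows an inconsistent node, but the extra
// conjunct keeps it from ever becoming resolved, because
// (isGrouped || groupingExpressions.isEmpty) evaluates to false:
val inconsistent = Aggregate(
  Seq('b.attr), Seq(sum('a).as("total")), testRelation, isGrouped = false)
// inconsistent.resolved == false
```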
@@ -126,7 +126,7 @@ class SQLBuilder private (
case p: Project =>
projectToSQL(p, isDistinct = false)

- case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+ case a @ Aggregate(_, _, e @ Expand(_, _, p: Project), _) if isGroupingSet(a, e, p) =>
groupingSetToSQL(a, e, p)

case p: Aggregate =>
@@ -47,7 +47,7 @@ case class OptimizeMetadataOnlyQuery(
}

plan.transform {
- case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) =>
+ case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation), _) =>
// We only apply this optimization when only partitioned attributes are scanned.
if (a.references.subsetOf(partAttrs)) {
val aggFunctions = aggExprs.flatMap(_.collect {
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, SaveMode, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
@@ -228,13 +227,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object StatefulAggregationStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalAggregation(
- namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+ namedGroupingExpressions,
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ child,
+ isGrouped) =>

aggregate.AggUtils.planStreamingAggregation(
namedGroupingExpressions,
aggregateExpressions,
rewrittenResultExpressions,
- planLater(child))
+ planLater(child),
+ isGrouped)

case _ => Nil
}
@@ -246,7 +250,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalAggregation(
- groupingExpressions, aggregateExpressions, resultExpressions, child) =>
+ groupingExpressions, aggregateExpressions, resultExpressions, child, isGrouped) =>

val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
@@ -267,21 +271,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
groupingExpressions,
aggregateExpressions,
resultExpressions,
- planLater(child))
+ planLater(child),
+ isGrouped)
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
- planLater(child))
+ planLater(child),
+ isGrouped)
} else {
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
resultExpressions,
- planLater(child))
+ planLater(child),
+ isGrouped)
}

aggregateOperator
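Background on why both strategies now forward the flag to the AggUtils planners, plausibly to keep apart two cases that behave differently once literal grouping keys have been pruned. The DataFrame calls below are an illustration of standard semantics, not code from this diff; `spark` is a hypothetical SparkSession:

```scala
import org.apache.spark.sql.functions.{count, lit}

val empty = spark.range(0)

val globalAgg  = empty.agg(count(lit(1)))                 // global aggregate: exactly one row
val groupedAgg = empty.groupBy(lit(1)).agg(count(lit(1))) // grouped aggregate: zero rows
// Once the optimizer prunes the literal grouping key, only the isGrouped flag carried
// through PhysicalAggregation distinguishes these two plans at planning time.
```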