diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8595762988b4..65d5a4640ea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -962,34 +962,42 @@ class Analyzer(
             unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
           val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
           val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
-          val resolvedAliasedOrdering: Seq[Alias] =
-            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
 
           // If we pass the analysis check, then the ordering expressions should only reference to
           // aggregate expressions or grouping expressions, and it's safe to push them down to
           // Aggregate.
           checkAnalysis(resolvedAggregate)
 
-          val originalAggExprs = aggregate.aggregateExpressions.map(
-            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+          val resolvedOrdering =
+            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]].map(_.child)
 
-          // If the ordering expression is same with original aggregate expression, we don't need
-          // to push down this ordering expression and can reference the original aggregate
-          // expression instead.
+          // Collect the aggregate expressions and grouping expressions that cannot be evaluated
+          // as ordering expressions. They are pushed down into the aggregate list here and
+          // projected away at the end.
           val needsPushDown = ArrayBuffer.empty[NamedExpression]
-          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
-            case (evaluated, order) =>
-              val index = originalAggExprs.indexWhere {
-                case Alias(child, _) => child semanticEquals evaluated.child
-                case other => other semanticEquals evaluated.child
-              }
-
-              if (index == -1) {
-                needsPushDown += evaluated
-                order.copy(child = evaluated.toAttribute)
-              } else {
-                order.copy(child = originalAggExprs(index).toAttribute)
+          val groupingExpressions = resolvedAggregate.groupingExpressions
+          val evaluatedOrderings = resolvedOrdering.zip(sortOrder).map {
+            case (resolved, order) =>
+              val evaluated = resolved.transformDown {
+                case e if e.isInstanceOf[AggregateExpression] ||
+                    groupingExpressions.exists(e.semanticEquals) =>
+                  val alreadyPushed = needsPushDown.find {
+                    case Alias(child, _) => e semanticEquals child
+                    case other => e semanticEquals other
+                  }
+
+                  if (alreadyPushed.isDefined) {
+                    alreadyPushed.get.toAttribute
+                  } else {
+                    val named = e match {
+                      case n: NamedExpression => n
+                      case _ => Alias(e, "aggOrder")()
+                    }
+                    needsPushDown += named
+                    named.toAttribute
+                  }
               }
+              order.copy(child = evaluated)
           }
 
           val sortOrdersMap = unresolvedSortOrders
@@ -1003,9 +1011,10 @@ class Analyzer(
           if (sortOrder == finalSortOrders) {
             sort
           } else {
+            val newAggExprs: Seq[NamedExpression] = aggregate.aggregateExpressions ++ needsPushDown
             Project(aggregate.output,
               Sort(finalSortOrders, global,
-                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
+                aggregate.copy(aggregateExpressions = newAggExprs)))
           }
         } catch {
           // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
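With this change, an ordering expression no longer has to match an entire aggregate expression to be resolved against the Aggregate below it: transformDown rewrites each embedded aggregate or grouping subexpression to an attribute and pushes only that part down into the aggregate list, instead of aliasing and pushing down the whole ordering expression. A sketch of the resulting plan shape in the catalyst test DSL (the relation and names are illustrative, not part of the patch):

    // ORDER BY sum(b) + 1: only sum(b) is pushed down under the generated
    // name "aggOrder"; the "+ 1" stays in the Sort itself.
    testRelation
      .groupBy('a)('a, sum('b).as("aggOrder"))  // sum(b) appended to the aggregate list
      .orderBy(('aggOrder + 1).asc)             // ordering references the pushed-down attribute
      .select('a)                               // the helper column is projected away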
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e6d554565d44..af40649d1c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -98,7 +98,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
-      EliminateSerialization) ::
+      EliminateSerialization,
+      RemoveUnnecessarySortOrderEvaluation) ::
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
@@ -1557,3 +1558,50 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 }
+
+/**
+ * Removes unnecessary evaluation from SortOrder when it is already present in the child's output.
+ *
+ * As an example, "SELECT a + 1 AS add, b FROM t ORDER BY a + 1" will be analyzed into:
+ * {{{
+ *   Project('add#2, 'b#1,
+ *     Sort('a#0 + 1,
+ *       Project(('a#0 + 1).as("add")#2, 'b#1, 'a#0),
+ *         Relation)
+ * }}}
+ * Then this rule can optimize it into:
+ * {{{
+ *   Project('add#2, 'b#1,
+ *     Sort('add#2,
+ *       Project(('a#0 + 1).as("add")#2, 'b#1, 'a#0),
+ *         Relation)
+ * }}}
+ * Finally, other optimizer rules (column pruning, project collapsing) can turn it into:
+ * {{{
+ *   Sort('add#2,
+ *     Project(('a#0 + 1).as("add")#2, 'b#1),
+ *       Relation)
+ * }}}
+ */
+object RemoveUnnecessarySortOrderEvaluation extends Rule[LogicalPlan] {
+
+  private def optimizeSortOrders(orders: Seq[SortOrder], childColumns: Seq[NamedExpression]) = {
+    orders.map { order =>
+      val newChild = order.child transformDown {
+        case expr => childColumns.find {
+          case Alias(child, _) => child semanticEquals expr
+          case other => other semanticEquals expr
+        }.map(_.toAttribute).getOrElse(expr)
+      }
+      order.copy(child = newChild)
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case sort @ Sort(sortOrders, _, Project(projectList, _)) =>
+      sort.copy(order = optimizeSortOrders(sortOrders, projectList))
+
+    case sort @ Sort(sortOrders, _, Aggregate(_, aggExprs, _)) =>
+      sort.copy(order = optimizeSortOrders(sortOrders, aggExprs))
+  }
+}
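A Rule[LogicalPlan] is just a function from plan to plan, so the new rule can be exercised directly. A minimal sketch, assuming the catalyst test DSL used by the suite added below:

    // The sort re-evaluates 'a * 2 even though the child projection already
    // computes it under the name "eval".
    val plan = testRelation
      .select(('a * 2).as("eval"), 'b)
      .orderBy(('a * 2).asc)
      .analyze

    // optimizeSortOrders finds the semantically equal alias in the project list
    // and substitutes its attribute, so the result orders by 'eval instead.
    val optimized = RemoveUnnecessarySortOrderEvaluation(plan)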
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a63d1770f325..abb30a4b1ad6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -129,7 +129,6 @@ class AnalysisSuite extends AnalysisTest {
     val b = testRelation2.output(1)
     val c = testRelation2.output(2)
     val alias_a3 = count(a).as("a3")
-    val alias_b = b.as("aggOrder")
 
     // Case 1: when the child of Sort is not Aggregate,
     // the sort reference is handled by the rule ResolveSortReferences
@@ -153,8 +152,8 @@ class AnalysisSuite extends AnalysisTest {
         .orderBy('b.asc)
 
     val expected2 = testRelation2
-        .groupBy(a, c, b)(a, c, alias_a3, alias_b)
-        .orderBy(alias_b.toAttribute.asc)
+        .groupBy(a, c, b)(a, c, alias_a3, b)
+        .orderBy(b.asc)
         .select(a, c, alias_a3.toAttribute)
 
     checkAnalysis(plan2, expected2)
@@ -316,8 +315,8 @@ class AnalysisSuite extends AnalysisTest {
       .orderBy('a1.asc, 'c.asc)
 
     val expected = testRelation2
-      .groupBy(a, c)(alias1, alias2, alias3)
-      .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
+      .groupBy(a, c)(alias1, alias2, alias3, c)
+      .orderBy(alias1.toAttribute.asc, c.asc)
       .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
     checkAnalysis(plan, expected)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SortOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SortOptimizationSuite.scala
new file mode 100644
index 000000000000..011f83682ffd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SortOptimizationSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class SortOptimizationSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Operator Optimizations", FixedPoint(100),
+        ColumnPruning,
+        CollapseProject,
+        RemoveUnnecessarySortOrderEvaluation) :: Nil
+  }
+
+  private val testRelation = LocalRelation('a.int, 'b.int)
+
+  test("sort on projection") {
+    val query =
+      testRelation
+        .select(('a * 2).as("eval"), 'b)
+        .orderBy(('a * 2).asc, ('a * 2 + 3).asc).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      testRelation
+        .select(('a * 2).as("eval"), 'b)
+        .orderBy('eval.asc, ('eval + 3).asc).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("sort on aggregation") {
+    val query =
+      testRelation
+        .groupBy('a * 2)(sum('b).as("sum"))
+        .orderBy(('a * 2).asc, ('a * 2 + 3).asc).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      testRelation
+        .groupBy('a * 2)(sum('b).as("sum"), ('a * 2).as("aggOrder"))
+        .orderBy('aggOrder.asc, ('aggOrder + 3).asc)
+        .select('sum)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("sort on aggregation: order by aggregate expression") {
+    val query =
+      testRelation
+        .groupBy('a)('a)
+        .orderBy(sum('b).asc, (sum('b) + 1).asc).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      testRelation
+        .groupBy('a)('a, sum('b).as("aggOrder"))
+        .orderBy('aggOrder.asc, ('aggOrder + 1).asc)
+        .select('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
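One behavior the tests above rely on implicitly: optimizeSortOrders only substitutes subexpressions that are semantically equal to a whole child column (or to an alias's child); everything else is kept through the getOrElse(expr) fallback. A sketch under the same DSL assumptions (not part of the patch):

    // 'a + 'b is not computed by the child projection, so no whole-expression
    // substitution happens; 'a and 'b individually resolve to their own
    // attributes, leaving the plan semantically unchanged.
    val plan = testRelation
      .select('a, 'b)
      .orderBy(('a + 'b).asc)
      .analyze
    val unchanged = RemoveUnnecessarySortOrderEvaluation(plan)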