Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -962,34 +962,42 @@ class Analyzer(
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]

// If we pass the analysis check, then the ordering expressions should only reference
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)

val originalAggExprs = aggregate.aggregateExpressions.map(
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
val resolvedOrdering =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]].map(_.child)

// If the ordering expression is same with original aggregate expression, we don't need
// to push down this ordering expression and can reference the original aggregate
// expression instead.
// Collects the aggregate expressions and grouping expressions that cannot be evaluated
// as ordering expressions. They need to be pushed down and added to the aggregate list;
// we will project them away at the end.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}

if (index == -1) {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
} else {
order.copy(child = originalAggExprs(index).toAttribute)
val groupingExpressions = resolvedAggregate.groupingExpressions
val evaluatedOrderings = resolvedOrdering.zip(sortOrder).map {
case (resolved, order) =>
val evaluated = resolved.transformDown {
case e if e.isInstanceOf[AggregateExpression] ||
groupingExpressions.exists(e.semanticEquals) =>
val alreadyPushed = needsPushDown.find {
case Alias(child, _) => e semanticEquals child
case other => e semanticEquals other
}

if (alreadyPushed.isDefined) {
alreadyPushed.get.toAttribute
} else {
val named = e match {
case n: NamedExpression => n
case _ => Alias(e, "aggOrder")()
}
needsPushDown += named
named.toAttribute
}
}
order.copy(child = evaluated)
}

val sortOrdersMap = unresolvedSortOrders
Expand All @@ -1003,9 +1011,10 @@ class Analyzer(
if (sortOrder == finalSortOrders) {
sort
} else {
val newAggExprs: Seq[NamedExpression] = aggregate.aggregateExpressions ++ needsPushDown
Project(aggregate.output,
Sort(finalSortOrders, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
aggregate.copy(aggregateExpressions = newAggExprs)))
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
EliminateSerialization) ::
EliminateSerialization,
RemoveUnnecessarySortOrderEvaluation) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
Expand Down Expand Up @@ -1557,3 +1558,50 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
}

/**
 * Rewrites the expressions inside a [[SortOrder]] so that they refer to columns the child
 * operator already outputs, instead of evaluating the same expression again.
 *
 * For example, "SELECT a + 1 AS add, b FROM t ORDER BY a + 1" will be analyzed into:
 * {{{
 *   Project('add#2, 'b#1,
 *     Sort('a#0 + 1,
 *       Project(('a#0 + 1).as("add")#2, 'b#1, 'a#0,
 *         Relation)))
 * }}}
 * This rule then turns it into:
 * {{{
 *   Project('add#2, 'b#1,
 *     Sort('add#2,
 *       Project(('a#0 + 1).as("add")#2, 'b#1, 'a#0,
 *         Relation)))
 * }}}
 * Finally, other optimizer rules (column pruning, project collapse) can turn it into:
 * {{{
 *   Sort('add#2,
 *     Project(('a#0 + 1).as("add")#2, 'b#1,
 *       Relation))
 * }}}
 */
object RemoveUnnecessarySortOrderEvaluation extends Rule[LogicalPlan] {

  /**
   * Walks every sort expression top-down and replaces each sub-expression that is semantically
   * equal to one of `childColumns` with that column's attribute. An aliased column is compared
   * through the expression it aliases; any other column is compared directly. The first matching
   * column wins, and sub-expressions without a match are left as they are.
   */
  private def optimizeSortOrders(orders: Seq[SortOrder], childColumns: Seq[NamedExpression]) = {
    for (sortOrder <- orders) yield {
      val rewritten = sortOrder.child transformDown {
        case candidate =>
          val producer = childColumns.find {
            case Alias(aliased, _) => aliased semanticEquals candidate
            case column => column semanticEquals candidate
          }
          producer match {
            case Some(column) => column.toAttribute
            case None => candidate
          }
      }
      sortOrder.copy(child = rewritten)
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Sort directly on top of a Project: reuse the columns of the project list.
    case s @ Sort(ordering, _, Project(outputColumns, _)) =>
      s.copy(order = optimizeSortOrders(ordering, outputColumns))

    // Sort directly on top of an Aggregate: reuse the aggregate's output expressions.
    case s @ Sort(ordering, _, Aggregate(_, outputColumns, _)) =>
      s.copy(order = optimizeSortOrders(ordering, outputColumns))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ class AnalysisSuite extends AnalysisTest {
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
Expand All @@ -153,8 +152,8 @@ class AnalysisSuite extends AnalysisTest {
.orderBy('b.asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.groupBy(a, c, b)(a, c, alias_a3, b)
.orderBy(b.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
Expand Down Expand Up @@ -316,8 +315,8 @@ class AnalysisSuite extends AnalysisTest {
.orderBy('a1.asc, 'c.asc)

val expected = testRelation2
.groupBy(a, c)(alias1, alias2, alias3)
.orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
.groupBy(a, c)(alias1, alias2, alias3, c)
.orderBy(alias1.toAttribute.asc, c.asc)
.select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
checkAnalysis(plan, expected)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan-level tests for [[RemoveUnnecessarySortOrderEvaluation]], run together with
 * [[ColumnPruning]] and [[CollapseProject]] in a single fixed-point batch.
 */
class SortOptimizationSuite extends PlanTest {

  /** Minimal optimizer containing only the rules exercised by this suite. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", FixedPoint(100),
        ColumnPruning,
        CollapseProject,
        RemoveUnnecessarySortOrderEvaluation) :: Nil
  }

  // Two-column integer relation shared by all test cases.
  private val relation = LocalRelation('a.int, 'b.int)

  test("sort on projection") {
    // ORDER BY repeats the projected expression `a * 2`, directly and as part of `a * 2 + 3`.
    val projected = relation.select(('a * 2).as("eval"), 'b)
    val analyzed = projected.orderBy(('a * 2).asc, ('a * 2 + 3).asc).analyze
    val actual = Optimize.execute(analyzed)

    // Both orderings are expected to refer to the projected column `eval`.
    val expected =
      relation
        .select(('a * 2).as("eval"), 'b)
        .orderBy('eval.asc, ('eval + 3).asc)
        .analyze

    comparePlans(actual, expected)
  }

  test("sort on aggregation") {
    // ORDER BY uses the grouping expression `a * 2`, which is not in the aggregate list.
    val grouped = relation.groupBy('a * 2)(sum('b).as("sum"))
    val analyzed = grouped.orderBy(('a * 2).asc, ('a * 2 + 3).asc).analyze

    val actual = Optimize.execute(analyzed)

    // The grouping expression is expected as an extra "aggOrder" column that is projected away.
    val expected =
      relation
        .groupBy('a * 2)(sum('b).as("sum"), ('a * 2).as("aggOrder"))
        .orderBy('aggOrder.asc, ('aggOrder + 3).asc)
        .select('sum)
        .analyze

    comparePlans(actual, expected)
  }

  test("sort on aggregation: order by aggregate expression") {
    // ORDER BY uses the aggregate `sum(b)`, which is not in the aggregate list.
    val grouped = relation.groupBy('a)('a)
    val analyzed = grouped.orderBy(sum('b).asc, (sum('b) + 1).asc).analyze

    val actual = Optimize.execute(analyzed)

    // The aggregate is expected as an extra "aggOrder" column that is projected away.
    val expected =
      relation
        .groupBy('a)('a, sum('b).as("aggOrder"))
        .orderBy('aggOrder.asc, ('aggOrder + 1).asc)
        .select('a)
        .analyze

    comparePlans(actual, expected)
  }
}