diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f3f64031843e..9519a56c2817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1020,7 +1020,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * Note that changes in the final output ordering may affect the file size (SPARK-32318). * This rule handles the following cases: * 1) if the sort order is empty or the sort order does not have any reference - * 2) if the child is already sorted + * 2) if the Sort operator is a local sort and the child is already sorted * 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or * RepartitionByExpression (with deterministic expressions) operators * 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or @@ -1031,12 +1031,18 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally + + private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) - if (newOrders.isEmpty) child else s.copy(order = newOrders) - case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => - child + if (newOrders.isEmpty) { + applyLocally.lift(child).getOrElse(child) + } else { + s.copy(order = newOrders) + } + case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + applyLocally.lift(child).getOrElse(child) case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 302439839996..d84dfcc8f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1253,6 +1253,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts") + .internal() + .doc("Whether to remove redundant physical sort node") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = buildConf("spark.sql.streaming.stateStore.providerClass") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index cc351e365113..62deebd93075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -99,12 +99,34 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } - test("remove redundant order by") { + test("SPARK-33183: remove consecutive no-op sorts") { + val plan = testRelation.orderBy().orderBy().orderBy() + val optimized = Optimize.execute(plan.analyze) + val correctAnswer = testRelation.analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove redundant sort by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) val correctAnswer = orderedPlan.limit(2).select('a).analyze - comparePlans(Optimize.execute(optimized), correctAnswer) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove all redundant local sorts") { + val orderedPlan = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.orderBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: should not remove global sort") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(reordered.analyze) + val correctAnswer = reordered.analyze + comparePlans(optimized, correctAnswer) } test("do not remove sort if the order is different") { @@ -115,22 +137,39 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } - test("filters don't affect order") { + test("SPARK-33183: remove top level local sort with filter operators") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = orderedPlan.where('a > Literal(10)).analyze comparePlans(optimized, correctAnswer) } - test("limits don't affect order") { + test("SPARK-33183: keep top level global sort with filter operators") { + val projectPlan = testRelation.select('a, 'b) + val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: limits should not affect order for local sort") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = orderedPlan.limit(Literal(10)).analyze comparePlans(optimized, correctAnswer) } + test("SPARK-33183: should not remove global sort with limit operators") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = filteredAndReordered.analyze + comparePlans(optimized, correctAnswer) + } + test("different sorts are not simplified if limit is in between") { val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) .orderBy('a.asc) @@ -139,11 +178,11 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } - test("range is already sorted") { + test("SPARK-33183: should not remove global sort with range operator") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) val optimized = Optimize.execute(orderedPlan.analyze) - val correctAnswer = inputPlan.analyze + val correctAnswer = orderedPlan.analyze comparePlans(optimized, correctAnswer) val reversedPlan = inputPlan.orderBy('id.desc) @@ -154,10 +193,18 @@ class EliminateSortsSuite extends AnalysisTest { val negativeStepInputPlan = Range(10L, 1L, -1, 10) val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) - val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) } + test("SPARK-33183: remove local sort with range operator") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.sortBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("sort should not be removed when there is a node which doesn't guarantee any order") { val orderedPlan = testRelation.select('a, 'b) val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) @@ -333,4 +380,39 @@ class EliminateSortsSuite extends AnalysisTest { val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze) comparePlans(optimized, correctAnswer) } + + test("SPARK-33183: remove consecutive global sorts with the same ordering") { + Seq( + (testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)), + (testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc)) + ).foreach { case (ordered, answer) => + val optimized = Optimize.execute(ordered.analyze) + comparePlans(optimized, answer.analyze) + } + } + + test("SPARK-33183: remove consecutive local sorts with the same ordering") { + val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.sortBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: remove consecutive local sorts with different ordering") { + val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = testRelation.sortBy('a.asc).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") { + val correctAnswer = testRelation.orderBy('a.asc).analyze + Seq( + testRelation.sortBy('a.asc).orderBy('a.asc), + testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc) + ).foreach { ordered => + val optimized = Optimize.execute(ordered.analyze) + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c37e1e92c857..b998430c1602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -343,6 +343,7 @@ object QueryExecution { PlanDynamicPruningFilters, PlanSubqueries, RemoveRedundantProjects, + RemoveRedundantSorts, EnsureRequirements, DisableUnnecessaryBucketedScan, ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.columnarRules), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala new file mode 100644 index 000000000000..87c08ec865fe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * Remove redundant SortExec node from the spark plan. A sort node is redundant when + * its child satisfies both its sort orders and its required child distribution. Note + * this rule differs from the Optimizer rule EliminateSorts in that this rule also checks + * if the child satisfies the required distribution so that it is safe to remove not only a + * local sort but also a global sort when its child already satisfies required sort orders. + */ +object RemoveRedundantSorts extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) { + plan + } else { + removeSorts(plan) + } + } + + private def removeSorts(plan: SparkPlan): SparkPlan = plan transform { + case s @ SortExec(orders, _, child, _) + if SortOrder.orderingSatisfies(child.outputOrdering, orders) && + child.outputPartitioning.satisfies(s.requiredChildDistribution.head) => + child + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index d30e16276b9f..a4a58dfe1de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -83,6 +83,7 @@ case class AdaptiveSparkPlanExec( @transient private val optimizer = new AQEOptimizer(conf) @transient private val removeRedundantProjects = RemoveRedundantProjects + @transient private val removeRedundantSorts = RemoveRedundantSorts @transient private val ensureRequirements = EnsureRequirements // A list of physical plan rules to be applied before creation of query stages. The physical @@ -90,6 +91,7 @@ case class AdaptiveSparkPlanExec( // Exchange nodes) after running these rules. private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( removeRedundantProjects, + removeRedundantSorts, ensureRequirements ) ++ context.session.sessionState.queryStagePrepRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 048466b3d863..be29acb6d3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -234,19 +234,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } } - test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { - val query = testData.select('key, 'value).sort('key.desc).cache() - assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) - val resorted = query.sort('key.desc) - assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) - assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == - (1 to 100).reverse) - // with a different order, the sort is needed - val sortedAsc = query.sort('key) - assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) - assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) - } - test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala new file mode 100644 index 000000000000..54c5a3344190 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + + +abstract class RemoveRedundantSortsSuiteBase + extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + import testImplicits._ + + private def checkNumSorts(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) + } + + private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { + val df = sql(query) + checkNumSorts(df, enabledCount) + val result = df.collect() + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") { + val df = sql(query) + checkNumSorts(df, disabledCount) + checkAnswer(df, result) + } + } + } + + test("remove redundant sorts with limit") { + withTempView("t") { + spark.range(100).select('id as "key").createOrReplaceTempView("t") + val query = + """ + |SELECT key FROM + | (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10) + |ORDER BY key DESC + |""".stripMargin + checkSorts(query, 0, 1) + } + } + + test("remove redundant sorts with broadcast hash join") { + withTempView("t1", "t2") { + spark.range(1000).select('id as "key").createOrReplaceTempView("t1") + spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + + val queryTemplate = """ + |SELECT /*+ BROADCAST(%s) */ t1.key FROM + | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 + |JOIN + | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2 + |ON t1.key = t2.key + |ORDER BY %s + """.stripMargin + + // No sort should be removed since the stream side (t2) order DESC + // does not satisfy the required sort order ASC. + val buildLeftOrderByRightAsc = queryTemplate.format("t1", "t2.key ASC") + checkSorts(buildLeftOrderByRightAsc, 1, 1) + + // The top sort node should be removed since the stream side (t2) order DESC already + // satisfies the required sort order DESC. + val buildLeftOrderByRightDesc = queryTemplate.format("t1", "t2.key DESC") + checkSorts(buildLeftOrderByRightDesc, 0, 1) + + // No sort should be removed since the sort ordering from broadcast-hash join is based + // on the stream side (t2) and the required sort order is from t1. + val buildLeftOrderByLeftDesc = queryTemplate.format("t1", "t1.key DESC") + checkSorts(buildLeftOrderByLeftDesc, 1, 1) + + // The top sort node should be removed since the stream side (t1) order DESC already + // satisfies the required sort order DESC. + val buildRightOrderByLeftDesc = queryTemplate.format("t2", "t1.key DESC") + checkSorts(buildRightOrderByLeftDesc, 0, 1) + } + } + + test("remove redundant sorts with sort merge join") { + withTempView("t1", "t2") { + spark.range(1000).select('id as "key").createOrReplaceTempView("t1") + spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + val query = """ + |SELECT /*+ MERGE(t1) */ t1.key FROM + | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 + |JOIN + | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2 + |ON t1.key = t2.key + |ORDER BY t1.key + """.stripMargin + + val queryAsc = query + " ASC" + checkSorts(queryAsc, 2, 3) + + // The top level sort should not be removed since the child output ordering is ASC and + // the required ordering is DESC. + val queryDesc = query + " DESC" + checkSorts(queryDesc, 3, 3) + } + } + + test("cached sorted data doesn't need to be re-sorted") { + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { + val df = spark.range(1000).select('id as "key").sort('key.desc).cache() + val resorted = df.sort('key.desc) + val sortedAsc = df.sort('key.asc) + checkNumSorts(df, 0) + checkNumSorts(resorted, 0) + checkNumSorts(sortedAsc, 1) + val result = resorted.collect() + withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") { + val resorted = df.sort('key.desc) + checkNumSorts(resorted, 1) + checkAnswer(resorted, result) + } + } + } +} + +class RemoveRedundantSortsSuite extends RemoveRedundantSortsSuiteBase + with DisableAdaptiveExecutionSuite + +class RemoveRedundantSortsSuiteAE extends RemoveRedundantSortsSuiteBase + with EnableAdaptiveExecutionSuite