diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8b777bed706bf..e5f531ff2f51d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1105,23 +1105,45 @@ object EliminateSorts extends Rule[LogicalPlan] { } case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => applyLocally.lift(child).getOrElse(child) - case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => - j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) + j.copy(left = recursiveRemoveSort(originLeft, true), + right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => - g.copy(child = recursiveRemoveSort(originChild)) + g.copy(child = recursiveRemoveSort(originChild, true)) } - private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { - case Sort(_, _, child) => recursiveRemoveSort(child) - case other if canEliminateSort(other) => - other.withNewChildren(other.children.map(recursiveRemoveSort)) - case _ => plan + /** + * If the upper sort is global then we can remove the global or local sort recursively. + * If the upper sort is local then we can only remove the local sort recursively. + */ + private def recursiveRemoveSort( + plan: LogicalPlan, + canRemoveGlobalSort: Boolean): LogicalPlan = { + plan match { + case Sort(_, global, child) if canRemoveGlobalSort || !global => + recursiveRemoveSort(child, canRemoveGlobalSort) + case Sort(sortOrder, true, child) => + // For this case, the upper sort is local so the ordering of present sort is unnecessary, + // so here we only preserve its output partitioning using `RepartitionByExpression`. + // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. + // This behavior is same with original global sort. + RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) + case other if canEliminateGlobalSort(other) => + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, true))) + case _ => plan + } } private def canEliminateSort(plan: LogicalPlan): Boolean = plan match { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic + case _ => false + } + + private def canEliminateGlobalSort(plan: LogicalPlan): Boolean = plan match { case r: RepartitionByExpression => r.partitionExpressions.forall(_.deterministic) case _: Repartition => true case _ => false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 62deebd930752..ca7d386480977 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -415,4 +415,24 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-39835: Fix EliminateSorts remove global sort below the local sort") { + // global -> local + val plan = testRelation.orderBy($"a".asc).sortBy($"c".asc).analyze + val expect = RepartitionByExpression($"a".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan), expect) + + // global -> global -> local + val plan2 = testRelation.orderBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected2 = RepartitionByExpression($"b".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan2), expected2) + + // local -> global -> local + val plan3 = testRelation.sortBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected3 = RepartitionByExpression($"b".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan3), expected3) + } }