diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 78fb8b5de8886..653d735da263a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1528,21 +1528,31 @@ object EliminateSorts extends Rule[LogicalPlan] { } case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => applyLocally.lift(child).getOrElse(child) - case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => - j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) + j.copy(left = recursiveRemoveSort(originLeft, true), + right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => - g.copy(child = recursiveRemoveSort(originChild)) + g.copy(child = recursiveRemoveSort(originChild, true)) } - private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = { + /** + * If the upper sort is global then we can remove the global or local sort recursively. + * If the upper sort is local then we can only remove the local sort recursively. + */ + private def recursiveRemoveSort( + plan: LogicalPlan, + canRemoveGlobalSort: Boolean): LogicalPlan = { if (!plan.containsPattern(SORT)) { return plan } plan match { - case Sort(_, _, child) => recursiveRemoveSort(child) + case Sort(_, global, child) if canRemoveGlobalSort || !global => + recursiveRemoveSort(child, canRemoveGlobalSort) case other if canEliminateSort(other) => - other.withNewChildren(other.children.map(recursiveRemoveSort)) + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) + case other if canEliminateGlobalSort(other) => + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, true))) case _ => plan } } @@ -1550,6 +1560,10 @@ object EliminateSorts extends Rule[LogicalPlan] { private def canEliminateSort(plan: LogicalPlan): Boolean = plan match { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic + case _ => false + } + + private def canEliminateGlobalSort(plan: LogicalPlan): Boolean = plan match { case r: RepartitionByExpression => r.partitionExpressions.forall(_.deterministic) case r: RebalancePartitions => r.partitionExpressions.forall(_.deterministic) case _: Repartition => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 865a06368a42e..1d879a7065e92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -424,4 +424,20 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-39835: Fix EliminateSorts remove global sort below the local sort") { + // global -> local + val plan = testRelation.orderBy($"a".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan), plan) + + // global -> global -> local + val plan2 = testRelation.orderBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected2 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan2), expected2) + + // local -> global -> local + val plan3 = testRelation.sortBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected3 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan3), expected3) + } }