diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
index 0e89e4a31bf2e..add35796d0254 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalLimit, LogicalPlan, Project, Sort, Window}
 import org.apache.spark.sql.catalyst.rules.Rule
 
@@ -32,7 +32,7 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
   // The window frame of RankLike and RowNumberLike can only be UNBOUNDED PRECEDING to CURRENT ROW.
   private def supportsPushdownThroughWindow(
       windowExpressions: Seq[NamedExpression]): Boolean = windowExpressions.forall {
-    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(Nil, _,
+    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(_, _,
       SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true
     case _ => false
   }
@@ -40,17 +40,19 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
     case LocalLimit(limitExpr @ IntegerLiteral(limit),
-        window @ Window(windowExpressions, Nil, orderSpec, child))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child))
       if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
         limit < conf.topKSortFallbackThreshold =>
       // Sort is needed here because we need global sort.
-      window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
+      window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child)))
     // There is a Project between LocalLimit and Window if they do not have the same output.
     case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
-        window @ Window(windowExpressions, Nil, orderSpec, child)))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child)))
       if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
         limit < conf.topKSortFallbackThreshold =>
       // Sort is needed here because we need global sort.
-      project.copy(child = window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child))))
+      project.copy(child = window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child))))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
index f2c1f452d0203..3e547389ec627 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
@@ -155,15 +155,22 @@ class LimitPushdownThroughWindowSuite extends PlanTest {
       WithoutOptimize.execute(correctAnswer.analyze))
   }
 
-  test("Should not push down if partitionSpec is not empty") {
+  test("Should push down if partitionSpec is not empty") {
     val originalQuery = testRelation
       .select(a, b, c,
         windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
       .limit(2)
 
+    val correctAnswer = testRelation
+      .select(a, b, c)
+      .orderBy(a.asc, c.desc)
+      .limit(2)
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+
     comparePlans(
       Optimize.execute(originalQuery.analyze),
-      WithoutOptimize.execute(originalQuery.analyze))
+      WithoutOptimize.execute(correctAnswer.analyze))
   }
 
   test("Should not push down when child's maxRows smaller than limit value") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 00cbd73533ab9..766ee78af8eab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4116,6 +4116,49 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-34775 Push down limit through window when partitionSpec is not empty") {
+    withTable("t1", "t2") {
+      var numRows = 20
+      spark.range(numRows)
+        .selectExpr("id % 10 AS a", s"$numRows - id AS b")
+        .write
+        .saveAsTable("t1")
+
+      val df1 = spark.sql(
+        """
+          |SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
+          |RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
+          |DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
+          |FROM t1 LIMIT 3
+          |""".stripMargin)
+      val pushedLocalLimits1 = df1.queryExecution.optimizedPlan.collect {
+        case l @ LocalLimit(_, _: Sort) => l
+      }
+      assert(pushedLocalLimits1.length === 1)
+      checkAnswer(df1, Seq(Row(0, 10, 1, 1, 1), Row(0, 20, 2, 2, 2), Row(1, 9, 1, 1, 1)))
+
+
+      numRows = 10
+      spark.range(numRows)
+        .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
+        .write
+        .saveAsTable("t2")
+      val df2 = spark.sql(
+        """
+          |SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
+          |RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
+          |DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
+          |FROM t2 LIMIT 3
+          |""".stripMargin)
+      val pushedLocalLimits2 = df2.queryExecution.optimizedPlan.collect {
+        case l @ LocalLimit(_, _: Sort) => l
+      }
+      assert(pushedLocalLimits2.length === 1)
+      checkAnswer(df2,
+        Seq(Row(null, 2, 1, 1, 1), Row(null, 4, 2, 2, 2), Row(null, 6, 3, 3, 3)))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])