diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
index eaea167ee9ff2..b7c5a851f7bbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalLimit, LogicalPlan, Project, Sort, Window}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, WINDOW}
@@ -33,7 +33,7 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
   // The window frame of RankLike and RowNumberLike can only be UNBOUNDED PRECEDING to CURRENT ROW.
   private def supportsPushdownThroughWindow(
       windowExpressions: Seq[NamedExpression]): Boolean = windowExpressions.forall {
-    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(Nil, _,
+    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(_, _,
       SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true
     case _ => false
   }
@@ -42,17 +42,19 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
     _.containsAllPatterns(WINDOW, LIMIT), ruleId) {
-    // Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
+    // Adding an extra Limit below WINDOW, sorted by the partition keys and the window order.
     case LocalLimit(limitExpr @ IntegerLiteral(limit),
-        window @ Window(windowExpressions, Nil, orderSpec, child))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child))
         if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
           limit < conf.topKSortFallbackThreshold =>
       // Sort is needed here because we need global sort.
-      window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
+      window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child)))
     // There is a Project between LocalLimit and Window if they do not have the same output.
     case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
-        window @ Window(windowExpressions, Nil, orderSpec, child)))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child)))
        if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
          limit < conf.topKSortFallbackThreshold =>
      // Sort is needed here because we need global sort.
-      project.copy(child = window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child))))
+      project.copy(child = window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child))))
   }
 }
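Reviewer note: the rule previously fired only when every window function had an empty PARTITION BY; it now also fires for a non-empty partitionSpec by sorting on the partition keys (ascending) followed by the window's own orderSpec. Intuitively this is safe for rank-like functions: a retained row's rank depends only on rows that sort no later than it within its partition, and all such rows survive the pushed limit. Below is a minimal, self-contained sketch of a query shape this change newly optimizes; the local session setup, object name, and view name t are illustrative, not part of the patch.

import org.apache.spark.sql.SparkSession

object LimitPushDownThroughWindowDemo {
  def main(args: Array[String]): Unit = {
    // Illustrative only: mirrors the SPARK-34775 test data in SQLQuerySuite below.
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    spark.range(20).selectExpr("id % 10 AS a", "20 - id AS b").createOrReplaceTempView("t")
    val df = spark.sql(
      "SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t LIMIT 3")
    // With this patch the optimized plan should show LocalLimit(3) over
    // Sort [a ASC, b ASC] beneath the Window node.
    df.explain(true)
    spark.stop()
  }
}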
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
index f2c1f452d0203..402faad1bf801 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
@@ -155,15 +155,21 @@ class LimitPushdownThroughWindowSuite extends PlanTest {
       WithoutOptimize.execute(correctAnswer.analyze))
   }
 
-  test("Should not push down if partitionSpec is not empty") {
+  test("Should push down if partitionSpec is not empty") {
     val originalQuery = testRelation
       .select(a, b, c,
         windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
       .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b, c)
+      .orderBy(a.asc, c.desc)
+      .limit(2)
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
 
     comparePlans(
       Optimize.execute(originalQuery.analyze),
-      WithoutOptimize.execute(originalQuery.analyze))
+      WithoutOptimize.execute(correctAnswer.analyze))
   }
 
   test("Should not push down when child's maxRows smaller than limit value") {
@@ -187,4 +193,125 @@ class LimitPushdownThroughWindowSuite extends PlanTest {
       Optimize.execute(originalQuery.analyze),
       WithoutOptimize.execute(originalQuery.analyze))
   }
+
+  test("Should push down if partitionSpec contains multiple expressions") {
+    val originalQuery = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: b :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+      .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b, c)
+      .orderBy(a.asc, b.asc, c.desc)
+      .limit(2)
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: b :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(correctAnswer.analyze))
+  }
+
+  test("Push down limit through window for multiple window functions " +
+    "when all partitionSpecs are non-empty and identical") {
+    val originalQuery = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"),
+        windowExpr(new Rank(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rk"))
+      .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b, c)
+      .orderBy(a.asc, c.desc)
+      .limit(2)
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"),
+        windowExpr(new Rank(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rk"))
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(correctAnswer.analyze))
+  }
+
+  test("Push down limit through window for multiple window functions " +
+    "when partitionSpecs are non-empty and differ") {
+    val originalQuery = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"),
+        windowExpr(new Rank(), windowSpec(b :: Nil, c.desc :: Nil, windowFrame)).as("rk"))
+      .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+      .orderBy(b.asc, c.desc)
+      .limit(2)
+      .select(a, b, c, $"rn".attr,
+        windowExpr(new Rank(), windowSpec(b :: Nil, c.desc :: Nil, windowFrame)).as("rk"))
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(correctAnswer.analyze))
+  }
+
+  test("Push down limit through window respects spark.sql.execution.topKSortFallbackThreshold " +
+    "when partitionSpec is not empty") {
+    Seq(1, 100).foreach { threshold =>
+      withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> threshold.toString) {
+        val originalQuery = testRelation
+          .select(a, b, c,
+            windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+          .limit(2)
+        val correctAnswer = if (threshold == 1) {
+          originalQuery
+        } else {
+          testRelation
+            .select(a, b, c)
+            .orderBy(a.asc, c.desc)
+            .limit(2)
+            .select(a, b, c,
+              windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+        }
+
+        comparePlans(
+          Optimize.execute(originalQuery.analyze),
+          WithoutOptimize.execute(correctAnswer.analyze))
+      }
+    }
+  }
+
+  test("Push down to the first window if the order column is different " +
+    "when partitionSpec is not empty") {
+    val originalQuery = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, b.desc :: Nil, windowFrame)).as("rn"),
+        windowExpr(new Rank(), windowSpec(a :: Nil, c.asc :: Nil, windowFrame)).as("rk"))
+      .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, b.desc :: Nil, windowFrame)).as("rn"))
+      .orderBy(a.asc, c.asc)
+      .limit(2)
+      .select(a, b, c, $"rn".attr,
+        windowExpr(new Rank(), windowSpec(a :: Nil, c.asc :: Nil, windowFrame)).as("rk"))
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(correctAnswer.analyze))
+  }
+
+  test("Should push down if there is a Project between LocalLimit and Window " +
+    "when partitionSpec is not empty") {
+    val originalQuery = testRelation
+      .select(a, b,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, b.desc :: Nil, windowFrame)).as("rn"))
+      .select(a, $"rn".attr)
+      .limit(2)
+    val correctAnswer = testRelation
+      .select(a, b)
+      .orderBy(a.asc, b.desc)
+      .limit(2)
+      .select(a, windowExpr(RowNumber(), windowSpec(a :: Nil, b.desc :: Nil, windowFrame)).as("rn"))
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(correctAnswer.analyze))
+  }
 }
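The threshold tests above exercise the "limit < conf.topKSortFallbackThreshold" guard. A quick way to observe that gating interactively, assuming a spark-shell session and the same illustrative view t from the sketch above:

// Hypothetical spark-shell check: with the threshold lowered to 1, LIMIT 2 no
// longer qualifies for the top-K push-down and the Limit stays above the Window;
// unsetting the conf restores the default and re-enables the rewrite.
spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", "1")
spark.sql("SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t LIMIT 2").explain()
spark.conf.unset("spark.sql.execution.topKSortFallbackThreshold")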
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 26af0b9f81127..f0ab1312cd217 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4112,6 +4112,48 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     }
     checkAnswer(sql(s"select /*+ REPARTITION(3, a) */ a b from values('123') t(a)"), Row("123"))
   }
+
+  test("SPARK-34775: Push down limit through window when partitionSpec is not empty") {
+    withTable("t1", "t2") {
+      var numRows = 20
+      spark.range(numRows)
+        .selectExpr("id % 10 AS a", s"$numRows - id AS b")
+        .write
+        .saveAsTable("t1")
+
+      val df1 = spark.sql(
+        """
+          |SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
+          |RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
+          |DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
+          |FROM t1 LIMIT 3
+          |""".stripMargin)
+      val pushedLocalLimits1 = df1.queryExecution.optimizedPlan.collect {
+        case l @ LocalLimit(_, _: Sort) => l
+      }
+      assert(pushedLocalLimits1.length === 1)
+      checkAnswer(df1, Seq(Row(0, 10, 1, 1, 1), Row(0, 20, 2, 2, 2), Row(1, 9, 1, 1, 1)))
+
+      numRows = 10
+      spark.range(numRows)
+        .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
+        .write
+        .saveAsTable("t2")
+      val df2 = spark.sql(
+        """
+          |SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
+          |RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
+          |DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
+          |FROM t2 LIMIT 3
+          |""".stripMargin)
+      val pushedLocalLimits2 = df2.queryExecution.optimizedPlan.collect {
+        case l @ LocalLimit(_, _: Sort) => l
+      }
+      assert(pushedLocalLimits2.length === 1)
+      checkAnswer(df2,
+        Seq(Row(null, 2, 1, 1, 1), Row(null, 4, 2, 2, 2), Row(null, 6, 3, 3, 3)))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
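A note on the t2 expectations above: SortOrder(_, Ascending) defaults to NULLS FIRST, so rows whose partition key a is null sort ahead of all non-null keys and are exactly the rows the pushed LIMIT 3 keeps. For completeness, a hypothetical follow-up check in the same style as the suite's assertions (it assumes the t2 table created by the test still exists in the session):

import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Sort}

// The optimized plan should contain exactly one LocalLimit whose child is a Sort,
// i.e. the limit was pushed below the Window despite the non-empty PARTITION BY.
val plan = spark.sql(
  "SELECT a, b, RANK() OVER (PARTITION BY a ORDER BY b) AS rk FROM t2 LIMIT 3")
  .queryExecution.optimizedPlan
assert(plan.collect { case l @ LocalLimit(_, _: Sort) => l }.size == 1)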