@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CurrentRow, IntegerLiteral, NamedExpression, RankLike, RowFrame, RowNumberLike, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalLimit, LogicalPlan, Project, Sort, Window}
import org.apache.spark.sql.catalyst.rules.Rule

@@ -32,25 +32,27 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
// The window frame of RankLike and RowNumberLike can only be UNBOUNDED PRECEDING to CURRENT ROW.
private def supportsPushdownThroughWindow(
windowExpressions: Seq[NamedExpression]): Boolean = windowExpressions.forall {
-    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(Nil, _,
+    case Alias(WindowExpression(_: RankLike | _: RowNumberLike, WindowSpecDefinition(_, _,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
+    // Adding an extra Limit below WINDOW. The pushed-down Sort orders by the partition keys first,
+    // so the limited rows form a prefix of each partition and rank-like values stay unchanged.
case LocalLimit(limitExpr @ IntegerLiteral(limit),
-        window @ Window(windowExpressions, Nil, orderSpec, child))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
-      window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
+      window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child)))
// There is a Project between LocalLimit and Window if they do not have the same output.
case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
-        window @ Window(windowExpressions, Nil, orderSpec, child)))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child)))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
-      project.copy(child = window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child))))
+      project.copy(child = window.copy(child = Limit(limitExpr,
+        Sort(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec, true, child))))
}
}
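Taken together, the rule change above lets a small LIMIT over rank-like window functions with a non-empty PARTITION BY be answered by a top-K sort below the window instead of sorting and windowing every row. A minimal DataFrame-level sketch of the query shape this targets (assuming a SparkSession named `spark` and the `t1(a, b)` table created by the SQLQuerySuite test further down; the rewrite only fires when the limit is below `conf.topKSortFallbackThreshold`):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// Rank rows within each partition of `a` by `b`, then keep only three rows overall.
val ranked = spark.table("t1")
  .withColumn("rn", row_number().over(Window.partitionBy("a").orderBy("b")))
  .limit(3)

// With the rule applied, the optimized plan should contain a LocalLimit directly above a
// global Sort on (a ASC, b ASC) below the Window, which is what the new tests assert.
ranked.explain(true)
```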
@@ -155,15 +155,22 @@ class LimitPushdownThroughWindowSuite extends PlanTest {
WithoutOptimize.execute(correctAnswer.analyze))
}

test("Should not push down if partitionSpec is not empty") {
test("Should push down if partitionSpec is not empty") {
val originalQuery = testRelation
.select(a, b, c,
windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
.limit(2)

+    val correctAnswer = testRelation
+      .select(a, b, c)
+      .orderBy(a.asc, c.desc)
+      .limit(2)
+      .select(a, b, c,
+        windowExpr(RowNumber(), windowSpec(a :: Nil, c.desc :: Nil, windowFrame)).as("rn"))
+
comparePlans(
Optimize.execute(originalQuery.analyze),
-      WithoutOptimize.execute(originalQuery.analyze))
+      WithoutOptimize.execute(correctAnswer.analyze))
}

test("Should not push down when child's maxRows smaller than limit value") {
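In the rewritten test above ("Should push down if partitionSpec is not empty"), the injected Sort is derived by mapping each partition key to an ascending SortOrder and appending the window's orderSpec. A small sketch of that key construction, reusing this suite's `a` and `c` attributes and the same `windowSpec(a :: Nil, c.desc :: Nil, windowFrame)` fixture (names from the suite are assumed to be in scope):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}

// PARTITION BY a, ORDER BY c DESC, as in the test's window spec.
val partitionSpec = Seq(a)
val orderSpec = Seq(c.desc)

// Partition keys come first so the first `limit` rows of the global sort form a prefix of
// each partition; rank-like values computed on that prefix match the unoptimized result.
val sortKeys = partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec
// sortKeys is (a ASC, c DESC), i.e. the `.orderBy(a.asc, c.desc)` used in correctAnswer.
```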
43 changes: 43 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4116,6 +4116,49 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}

test("SPARK-34775 Push down limit through window when partitionSpec is not empty") {
withTable("t1", "t2") {
var numRows = 20
spark.range(numRows)
.selectExpr("id % 10 AS a", s"$numRows - id AS b")
.write
.saveAsTable("t1")

val df1 = spark.sql(
"""
|SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
|RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
|DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
|FROM t1 LIMIT 3
|""".stripMargin)
val pushedLocalLimits1 = df1.queryExecution.optimizedPlan.collect {
case l @ LocalLimit(_, _: Sort) => l
}
assert(pushedLocalLimits1.length === 1)
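      // With the pushed-down Sort(a ASC, b ASC) and LIMIT 3, the result is the two rows of
      // partition a = 0 (b = 10 and 20) followed by the smallest b of partition a = 1 (b = 9).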
checkAnswer(df1, Seq(Row(0, 10, 1, 1, 1), Row(0, 20, 2, 2, 2), Row(1, 9, 1, 1, 1)))


numRows = 10
spark.range(numRows)
.selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
.write
.saveAsTable("t2")
val df2 = spark.sql(
"""
|SELECT a, b, ROW_NUMBER() OVER(PARTITION BY a ORDER BY b) AS rn,
|RANK() OVER(PARTITION BY a ORDER BY b) AS rk,
|DENSE_RANK() OVER(PARTITION BY a ORDER BY b) AS drk
|FROM t2 LIMIT 3
|""".stripMargin)
val pushedLocalLimits2 = df2.queryExecution.optimizedPlan.collect {
case l @ LocalLimit(_, _: Sort) => l
}
assert(pushedLocalLimits2.length === 1)
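      // Ascending order places NULLs first by default, so the a = NULL partition (the even ids)
      // supplies the first three rows, ordered by b.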
checkAnswer(df2,
Seq(Row(null, 2, 1, 1, 1), Row(null, 4, 2, 2, 2), Row(null, 6, 3, 3, 3)))
}
}
}

case class Foo(bar: Option[String])