
Commit cfbeea7

Merge branch 'unionLimit' into unionLimit2

2 parents 10d570c + 2823a57


5 files changed: +121 -3 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 37 additions & 3 deletions
@@ -53,7 +53,6 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       // Operator combine
       ProjectCollapsing,
       CombineFilters,
-      CombineLimits,
       // Constant folding
       NullPropagation,
       OptimizeIn,
@@ -64,6 +63,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       SimplifyFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions) ::
+    Batch("Push Down Limits", FixedPoint(100),
+      PushDownLimit,
+      CombineLimits,
+      ConstantFolding,
+      BooleanSimplification) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
@@ -79,6 +83,36 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
  */
 object DefaultOptimizer extends Optimizer
 
+/**
+ * Pushes down Limit to reduce the amount of returned data.
+ *
+ * 1. Adds an extra Limit beneath operators such as Union All.
+ * 2. Project is pushed through Limit in the rule ColumnPruning.
+ *
+ * Any operator that a Limit can be pushed past should override the maxRows function.
+ *
+ * Note: this rule has to run once the logical plan is stable;
+ * otherwise, it could impact the other rules.
+ */
+object PushDownLimit extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+
+    // Add an extra Limit below UNION ALL when the left or the right child is not a
+    // Limit and has no Limit descendant (i.e., its maxRows is empty). This heuristic
+    // is valid assuming there does not exist any Limit push-down rule that is unable
+    // to infer the value of maxRows.
+    // Note: right now, Union means UNION ALL, which does not de-duplicate rows, so it
+    // is safe to push Limit through it. Once we add UNION DISTINCT, we will not be
+    // able to push down Limit.
+    case Limit(exp, Union(left, right))
+        if left.maxRows.isEmpty || right.maxRows.isEmpty =>
+      Limit(exp,
+        Union(
+          Limit(exp, left),
+          Limit(exp, right)))
+  }
+}
+
 /**
  * Pushes operations down into a Sample.
  */
@@ -97,8 +131,8 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * Operations that are safe to pushdown are listed as follows.
  * Union:
  * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
- * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
- * we will not be able to pushdown Projections.
+ * safe to pushdown Filters, Projections and Limits through it. Once we add UNION DISTINCT,
+ * we will not be able to pushdown Projections and Limits.
  *
  * Intersect:
  * It is not safe to pushdown Projections through it because we need to get the
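For a concrete view of what the new "Push Down Limits" batch does, here is a minimal sketch using the Catalyst test DSL (t1 and t2 are stand-in relations introduced for illustration, not part of this commit):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, Union}

// Stand-in leaf relations; neither has a Limit, so maxRows is None for both.
val t1 = LocalRelation('a.int)
val t2 = LocalRelation('b.int)

// Before: the Limit sits above the Union, so both inputs are fully computed.
val before = Limit(5, Union(t1, t2))

// After PushDownLimit: each input is capped at 5 rows before the union, and the
// outer Limit still enforces the final row count of the combined result.
val after = Limit(5, Union(Limit(5, t1), Limit(5, t2)))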

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 7 additions & 0 deletions
@@ -90,6 +90,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
     Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
   }
 
+  /**
+   * Returns the limited number of rows to be returned.
+   *
+   * Any operator that a Limit can be pushed past should override this function.
+   */
+  def maxRows: Option[Expression] = None
+
   /**
    * Returns true if this expression and all its children have been resolved to a specific schema
    * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
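To illustrate the contract: a hypothetical unary operator that never changes its child's row count would simply delegate maxRows to its child, exactly as Project does in the next file (a sketch, not part of this commit):

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Hypothetical row-preserving operator: because it emits exactly as many rows as
// its child, a Limit above it can safely be pushed past it, so it forwards the
// child's maxRows instead of keeping the default None.
case class RowPreservingOp(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
  override def maxRows: Option[Expression] = child.maxRows
}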

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,8 @@ import scala.collection.mutable.ArrayBuffer
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
+  override def maxRows: Option[Expression] = child.maxRows
+
   override lazy val resolved: Boolean = {
     val hasSpecialExpressions = projectList.exists ( _.collect {
       case agg: AggregateExpression => agg
@@ -109,6 +111,9 @@ private[sql] object SetOperation {
 
 case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
 
+  override def maxRows: Option[Expression] =
+    for (leftMax <- left.maxRows; rightMax <- right.maxRows) yield Add(leftMax, rightMax)
+
   override def statistics: Statistics = {
     val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
     Statistics(sizeInBytes = sizeInBytes)
@@ -451,6 +456,8 @@ case class Pivot(
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
+  override def maxRows: Option[Expression] = Option(limitExpr)
+
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
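The for-comprehension in Union.maxRows does real work: it yields Add(leftMax, rightMax) only when both sides have a known bound, and collapses to None otherwise, which is exactly the signal PushDownLimit's guard checks. A plain-Option sketch of the same pattern (Int standing in for Expression, for brevity):

// Same shape as Union.maxRows, with Int in place of Expression.
def unionMax(left: Option[Int], right: Option[Int]): Option[Int] =
  for (leftMax <- left; rightMax <- right) yield leftMax + rightMax

unionMax(Some(3), Some(5))  // Some(8): both sides bounded, so the union is too
unionMax(Some(3), None)     // None: one unbounded side makes the union unbounded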
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushdownLimitsSuite.scala

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class PushdownLimitsSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubQueries) ::
+      Batch("Push Down Limit", Once,
+        PushDownLimit,
+        CombineLimits,
+        ConstantFolding,
+        BooleanSimplification) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+  test("Union: limit to each side") {
+    val unionQuery = Union(testRelation, testRelation2).limit(1)
+    val unionOptimized = Optimize.execute(unionQuery.analyze)
+    val unionCorrectAnswer =
+      Limit(1, Union(testRelation.limit(1), testRelation2.limit(1))).analyze
+    comparePlans(unionOptimized, unionCorrectAnswer)
+  }
+
+  test("Union: limit to each side with the new limit number") {
+    val testLimitUnion = Union(testRelation, testRelation2.limit(3))
+    val unionQuery = testLimitUnion.limit(1)
+    val unionOptimized = Optimize.execute(unionQuery.analyze)
+    val unionCorrectAnswer =
+      Limit(1, Union(testRelation.limit(1), testRelation2.limit(1))).analyze
+    comparePlans(unionOptimized, unionCorrectAnswer)
+  }
+
+  test("Union: no limit to both sides") {
+    val testLimitUnion = Union(testRelation.limit(2), testRelation2.select('d).limit(3))
+    val unionQuery = testLimitUnion.limit(2)
+    val unionOptimized = Optimize.execute(unionQuery.analyze)
+    val unionCorrectAnswer =
+      Limit(2, Union(testRelation.limit(2), testRelation2.select('d).limit(3))).analyze
+    comparePlans(unionOptimized, unionCorrectAnswer)
+  }
+}
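The second test is the one that shows why CombineLimits, ConstantFolding, and BooleanSimplification ride along in the batch. Immediately after PushDownLimit fires, the plan contains a nested limit on the right side; a sketch of that intermediate shape, reconstructed here for illustration:

// Intermediate plan after PushDownLimit in "limit to each side with the new limit number":
//   Limit(1, Union(Limit(1, testRelation), Limit(1, Limit(3, testRelation2))))
// CombineLimits then merges the nested pair into a single Limit whose bound is the
// smaller of the two expressions, and ConstantFolding (helped by BooleanSimplification)
// reduces that comparison of literals down to 1, giving the expected answer:
//   Limit(1, Union(testRelation.limit(1), testRelation2.limit(1)))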

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ class SetOperationPushDownSuite extends PlanTest {
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
         SetOperationPushDown,
+        CombineLimits,
+        ConstantFolding,
+        BooleanSimplification,
         SimplifyFilters) :: Nil
   }
 
