diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8..b7e16cc69594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -49,6 +49,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughProject, PushPredicateThroughGenerate, PushPredicateThroughAggregate, + PushDownLimit, ColumnPruning, // Operator combine ProjectCollapsing, @@ -79,6 +80,39 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { */ object DefaultOptimizer extends Optimizer +/** + * Pushes down Limit for reducing the amount of returned data. + * + * 1. Adding Extra Limit beneath the operations, including Union All. + * 2. Project is pushed through Limit in the rule ColumnPruning + * + * Any operator that a Limit can be pushed passed should override the maxRows function. + */ +object PushDownLimit extends Rule[LogicalPlan] { + + private def buildUnionChild (limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRows) match { + case (IntegerLiteral(maxRow), Some(IntegerLiteral(childMaxRows))) if maxRow < childMaxRows => + Limit(limitExp, plan) + case (_, None) => + Limit(limitExp, plan) + case _ => plan + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + + // Adding extra Limit below UNION ALL iff both left and right childs are not Limit or + // do not have Limit descendants whose maxRow is larger. This heuristic is valid assuming + // there does not exist any Limit push-down rule that is unable to infer the value of maxRows. + // Note, right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + // safe to pushdown Limit through it. Once we add UNION DISTINCT, we will not be able to + // pushdown Limit. + case Limit(exp, Union(left, right)) => + Limit(exp, Union(buildUnionChild(exp, left), buildUnionChild(exp, right))) + } +} + /** * Pushes operations down into a Sample. */ @@ -97,8 +131,8 @@ object SamplePushDown extends Rule[LogicalPlan] { * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is - * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, - * we will not be able to pushdown Projections. + * safe to pushdown Filters, Projections and Limits through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections and Limits. * * Intersect: * It is not safe to pushdown Projections through it because we need to get the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c5..687f8505a0ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -90,6 +90,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) } + /** + * Returns the limited number of rows to be returned. + * + * Any operator that a Limit can be pushed passed should override this function. (e.g., Union) + * Any operator that can push through a Limit should override this function. (e.g., Project) + */ + def maxRows: Option[Expression] = None + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5f34d4a4eb73..3b30a17175ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -28,6 +28,8 @@ import scala.collection.mutable.ArrayBuffer case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def maxRows: Option[Expression] = child.maxRows + override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { case agg: AggregateExpression => agg @@ -109,6 +111,9 @@ private[sql] object SetOperation { case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + override def maxRows: Option[Expression] = + for (leftMax <- left.maxRows; rightMax <- right.maxRows) yield Add(leftMax, rightMax) + override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes Statistics(sizeInBytes = sizeInBytes) @@ -451,6 +456,8 @@ case class Pivot( case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Expression] = Option(limitExpr) + override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushdownLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushdownLimitsSuite.scala new file mode 100644 index 000000000000..f65abb8bdd34 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushdownLimitsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class PushdownLimitsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Push Down Limit", Once, + PushDownLimit, + CombineLimits, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + test("Union: limit to each side") { + val unionQuery = Union(testRelation, testRelation2).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(testRelation.limit(1), testRelation2.limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with the new limit number") { + val testLimitUnion = Union(testRelation, testRelation2.limit(3)) + val unionQuery = testLimitUnion.limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(testRelation.limit(1), testRelation2.limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: no limit to both sides if children having smaller limit values") { + val testLimitUnion = Union(testRelation.limit(1), testRelation2.select('d).limit(1)) + val unionQuery = testLimitUnion.limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each sides if children having larger limit values") { + val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) + val unionQuery = testLimitUnion.limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(testRelation.limit(2), testRelation2.select('d).limit(2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } +}