diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f6088695a927..6fa5f99f373a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -47,6 +47,7 @@ object DefaultOptimizer extends Optimizer {
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
       PushPredicateThroughAggregate,
+      PushLimitThroughOuterJoin,
       ColumnPruning,
       // Operator combine
       ProjectCollapsing,
@@ -857,6 +858,60 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Push [[Limit]] operators through [[Join]] operators, or through [[Project]] + [[Join]]
+ * operators, iff the join type is an outer join.
+ * Case 1: If the join type is [[LeftOuter]] or [[RightOuter]], add an extra [[Limit]]
+ * operator on top of the outer-side child.
+ * Case 2: If the join type is [[FullOuter]] and exactly one child is a [[Limit]], add an
+ * extra [[Limit]] operator on top of the other child.
+ * Case 3: If the join type is [[FullOuter]] and neither child is a [[Limit]], add an extra
+ * [[Limit]] operator on top of the child whose estimated size is larger.
+ */
+object PushLimitThroughOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  private def makeNewJoinWithLimit(j: Join, limitExpr: Expression): Join = {
+    j.joinType match {
+      // RightOuter join:
+      // add an extra Limit on the right child
+      case RightOuter =>
+        Join(j.left, CombineLimits(Limit(limitExpr, j.right)), j.joinType, j.condition)
+      // LeftOuter join:
+      // add an extra Limit on the left child
+      case LeftOuter =>
+        Join(CombineLimits(Limit(limitExpr, j.left)), j.right, j.joinType, j.condition)
+      // FullOuter join whose left child is not a Limit but whose right child is:
+      // add an extra Limit on the right child
+      case FullOuter if !j.left.isInstanceOf[Limit] && j.right.isInstanceOf[Limit] =>
+        Join(j.left, CombineLimits(Limit(limitExpr, j.right)), j.joinType, j.condition)
+      // FullOuter join whose left child is a Limit but whose right child is not:
+      // add an extra Limit on the left child
+      case FullOuter if j.left.isInstanceOf[Limit] && !j.right.isInstanceOf[Limit] =>
+        Join(CombineLimits(Limit(limitExpr, j.left)), j.right, j.joinType, j.condition)
+      // FullOuter join where neither child is a Limit:
+      // add an extra Limit on the child whose estimated size is larger
+      case FullOuter if !j.left.isInstanceOf[Limit] && !j.right.isInstanceOf[Limit] =>
+        if (j.left.statistics.sizeInBytes <= j.right.statistics.sizeInBytes) {
+          Join(j.left, CombineLimits(Limit(limitExpr, j.right)), j.joinType, j.condition)
+        } else {
+          Join(CombineLimits(Limit(limitExpr, j.left)), j.right, j.joinType, j.condition)
+        }
+      // do nothing for the other cases
+      case _ => j
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Limit(limitExpr, j: Join) =>
+      Limit(limitExpr,
+        makeNewJoinWithLimit(j, limitExpr))
+    case Limit(limitExpr, Project(projectList, j: Join)) =>
+      Limit(limitExpr,
+        Project(projectList,
+          makeNewJoinWithLimit(j, limitExpr)))
+  }
+}
+
 /**
  * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
  */
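For illustration, a minimal sketch of the rewrite this rule performs on a left outer join, written with the Catalyst test DSL; the LocalRelation inputs here are assumptions for the example, not part of the patch:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.LeftOuter
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val x = LocalRelation('a.int, 'b.int, 'c.int) // hypothetical inputs
    val y = LocalRelation('d.int, 'e.int, 'f.int)

    // Before: the Limit sits above the join, so both inputs are joined in full first.
    val before = x.join(y, LeftOuter).limit(1)

    // After: the rule copies the Limit onto the outer (left) side. A left outer join
    // emits at least one row per left row, so limiting the left input to n rows still
    // yields at least n output rows; the Limit on top is kept because one-to-many
    // matches can produce more than n rows.
    val after = x.limit(1).join(y, LeftOuter).limit(1)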
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index fba4c5ca77d6..3c3d063fb668 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftSemi, LeftOuter, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -41,6 +41,7 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
+        PushLimitThroughOuterJoin,
         ColumnPruning,
         ProjectCollapsing) :: Nil
   }
@@ -750,4 +751,74 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("limit: push down left outer join") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, LeftOuter)
+        .limit(1)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.limit(1)
+    val correctAnswer =
+      left.join(y, LeftOuter).limit(1).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("limit: push down right outer join") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, RightOuter)
+        .limit(1)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val right = testRelation.limit(1)
+    val correctAnswer =
+      x.join(right, RightOuter).limit(1).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("limit: push down full outer join") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, FullOuter)
+        .limit(1)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation
+    val right = testRelation.limit(1)
+    val correctAnswer =
+      left.join(right, FullOuter).limit(1).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("limit: push down full outer join + project") {
+    val x = testRelation.subquery('x)
+    val y = testRelation1.subquery('y)
+
+    val originalQuery = {
+      x.join(y, FullOuter).select('a, 'b, 'd)
+        .limit(1)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.select('a, 'b)
+    val right = testRelation1.limit(1)
+    val correctAnswer =
+      left.join(right, FullOuter).select('a, 'b, 'd).limit(1).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
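A note on the FullOuter tests above: the empty test relations report equal size estimates, so the rule's tie-breaker j.left.statistics.sizeInBytes <= j.right.statistics.sizeInBytes places the extra Limit on the right child, which is why the expected plans limit the right side. A small sketch of that tie-breaking logic, with hypothetical sizes for illustration only:

    // The extra Limit goes on the side with the larger (or equal) estimated size.
    def limitSide(leftSize: BigInt, rightSize: BigInt): String =
      if (leftSize <= rightSize) "right" else "left"

    assert(limitSide(100, 100) == "right") // tie: right child, as the tests expect
    assert(limitSide(200, 100) == "left")  // larger left: limit the left child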
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 48cab01ac100..deb76ad0a0a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.immutable.HashSet
 
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}