diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2de92d06ec83..905984573283 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -61,15 +61,16 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       RemoveLiteralFromGroupExpressions) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
-      SetOperationPushDown,
       SamplePushDown,
       ReorderJoin,
       OuterJoinElimination,
-      PushPredicateThroughJoin,
       PushPredicateThroughProject,
+      SetOperationPushDown,
+      PushPredicateThroughJoin,
       PushPredicateThroughGenerate,
       PushPredicateThroughAggregate,
       LimitPushDown,
+      PushProjectThroughFilter,
       ColumnPruning,
       EliminateOperators,
       // Operator combine
@@ -91,6 +92,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
       EliminateSerialization) ::
+    // Because ColumnPruning is called after PushPredicateThroughProject, the predicate push down
+    // can be reversed. This batch ensures that Filter is pushed below Project, if possible.
+    Batch("Push Predicate Through Project", Once,
+      PushPredicateThroughProject) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
@@ -306,14 +311,28 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 }
 
 /**
- * Attempts to eliminate the reading of unneeded columns from the query plan using the following
- * transformations:
+ * Attempts to eliminate the reading of unneeded columns from the query plan
+ * by pushing Project through Filter.
  *
- *  - Inserting Projections beneath the following operators:
- *   - Aggregate
- *   - Generate
- *   - Project <- Join
- *   - LeftSemiJoin
+ * Note: This rule could reverse the effects of PushPredicateThroughProject.
+ * It should be run before ColumnPruning to ensure that Project can be pushed
+ * as low as possible.
+ */
+object PushProjectThroughFilter extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case p @ Project(projectList, f: Filter)
+        if f.condition.deterministic && projectList.forall(_.deterministic) =>
+      val required = f.references ++ p.references
+      if ((f.inputSet -- required).nonEmpty) {
+        p.copy(child = f.copy(child = ColumnPruning.prunedChild(f.child, required)))
+      } else {
+        p
+      }
+  }
+}
+
+/**
+ * Attempts to eliminate the reading of unneeded columns from the query plan.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
@@ -392,7 +411,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+  def prunedChild(c: LogicalPlan, allReferences: AttributeSet): LogicalPlan =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
       Project(c.output.filter(allReferences.contains), c)
     } else {
@@ -874,6 +893,10 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
  * that were defined in the projection.
  *
  * This heuristic is valid assuming the expression evaluation cost is minimal.
+ *
+ * Note: Because PushProjectThroughFilter could reverse the effect of PushPredicateThroughProject,
+ * PushPredicateThroughProject needs to run before the other predicate push-down rules so that
+ * predicates can be pushed as low as possible.
  */
 object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 6187fb9e2fb8..2ec4a934b803 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,6 +34,9 @@ class ColumnPruningSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
+      PushPredicateThroughProject,
+      PushPredicateThroughJoin,
+      PushProjectThroughFilter,
       ColumnPruning,
       EliminateOperators,
       CollapseProject) :: Nil
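
For reviewers who want to see the new rule in action, below is a minimal sketch in the style of ColumnPruningSuite, not part of the patch. The relation, the column names ('a, 'b, 'c) and the vals (testRelation, query, expectedShape) are invented for illustration, and the expected shape assumes only PushProjectThroughFilter has fired, before ColumnPruning and CollapseProject rewrite the plan further.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// A three-column relation; the query below only needs 'a and 'b.
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

// Project('a) on top of Filter('b > 1): the rule's `required` set is {a, b},
// while the Filter's input also carries the unused column 'c.
val query = testRelation.where('b > 1).select('a).analyze

// Because (f.inputSet -- required) = {c} is non-empty, PushProjectThroughFilter
// keeps the Filter below the Project and inserts a pruning Project underneath it
// via ColumnPruning.prunedChild:
//
//   Project('a)
//     Filter('b > 1)
//       Project('a, 'b)
//         LocalRelation('a, 'b, 'c)
val expectedShape = testRelation.select('a, 'b).where('b > 1).select('a).analyze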
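
A second small sketch of the transformation that the new "Push Predicate Through Project" batch re-applies after column pruning: pushing a Filter written on top of a Project back underneath it. Again the relation and vals are hypothetical and only illustrate the plan shapes involved.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int, 'c.int)

// Filter('a > 1) sitting on top of Project('a): per the patch's comment, this is the
// shape that can reappear once column pruning has run, which is why
// PushPredicateThroughProject is run once more in its own batch.
//
//   Filter('a > 1)
//     Project('a)
//       LocalRelation('a, 'b, 'c)
val filterOverProject = relation.select('a).where('a > 1).analyze

// PushPredicateThroughProject in-lines the (deterministic) project expressions into the
// condition and swaps the two operators, leaving the predicate next to the scan:
//
//   Project('a)
//     Filter('a > 1)
//       LocalRelation('a, 'b, 'c)
val predicatePushedDown = relation.where('a > 1).select('a).analyze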