diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d0a31e7620bb..18f2e7c28ac8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -80,7 +80,6 @@ class Analyzer(
       EliminateUnions),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
-      ResolveStar ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
@@ -374,91 +373,6 @@ class Analyzer(
     }
   }
 
-  /**
-   * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
-   */
-  object ResolveStar extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case p: LogicalPlan if !p.childrenResolved => p
-      // If the projection list contains Stars, expand it.
-      case p: Project if containsStar(p.projectList) =>
-        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
-          failAnalysis(
-            "Group by position: star is not allowed to use in the select list " +
-              "when using ordinals in group by")
-        } else {
-          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
-        }
-      // If the script transformation input contains Stars, expand it.
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child, resolver)
-            case o => o :: Nil
-          }
-        )
-      case g: Generate if containsStar(g.generator.children) =>
-        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
-    }
-
-    /**
-     * Build a project list for Project/Aggregate and expand the star if possible
-     */
-    private def buildExpandedProjectList(
-        exprs: Seq[NamedExpression],
-        child: LogicalPlan): Seq[NamedExpression] = {
-      exprs.flatMap {
-        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
-        case s: Star => s.expand(child, resolver)
-        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
-        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
-        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
-        case o => o :: Nil
-      }.map(_.asInstanceOf[NamedExpression])
-    }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
-
-    /**
-     * Expands the matching attribute.*'s in `child`'s output.
-     */
-    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
-      expr.transformUp {
-        case f1: UnresolvedFunction if containsStar(f1.children) =>
-          f1.copy(children = f1.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case c: CreateStruct if containsStar(c.children) =>
-          c.copy(children = c.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case c: CreateArray if containsStar(c.children) =>
-          c.copy(children = c.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        case p: Murmur3Hash if containsStar(p.children) =>
-          p.copy(children = p.children.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case o => o :: Nil
-          })
-        // count(*) has been replaced by count(1)
-        case o if containsStar(o.children) =>
-          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
-      }
-    }
-  }
-
   /**
    * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
    * a logical plan node's children.
@@ -525,6 +439,29 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
+      // If the projection list contains Stars, expand it.
+      case p: Project if containsStar(p.projectList) =>
+        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+          failAnalysis(
+            "Group by position: star is not allowed to use in the select list " +
+              "when using ordinals in group by")
+        } else {
+          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+        }
+      // If the script transformation input contains Stars, expand it.
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child, resolver)
+            case o => o :: Nil
+          }
+        )
+      case g: Generate if containsStar(g.generator.children) =>
+        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
@@ -619,6 +556,59 @@ class Analyzer(
     def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
       AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
     }
+
+    /**
+     * Build a project list for Project/Aggregate and expand the star if possible
+     */
+    private def buildExpandedProjectList(
+        exprs: Seq[NamedExpression],
+        child: LogicalPlan): Seq[NamedExpression] = {
+      exprs.flatMap {
+        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
+        case s: Star => s.expand(child, resolver)
+        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
+        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
+        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
+        case o => o :: Nil
+      }.map(_.asInstanceOf[NamedExpression])
+    }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
+    /**
+     * Expands the matching attribute.*'s in `child`'s output.
+     */
+    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+      expr.transformUp {
+        case f1: UnresolvedFunction if containsStar(f1.children) =>
+          f1.copy(children = f1.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateStruct if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateArray if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case p: Murmur3Hash if containsStar(p.children) =>
+          p.copy(children = p.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        // count(*) has been replaced by count(1)
+        case o if containsStar(o.children) =>
+          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
+      }
+    }
   }
 
   protected[sql] def resolveExpression(
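The Analyzer hunks above fold the old ResolveStar rule into ResolveReferences; the star-handling cases and helpers move over essentially verbatim. The heart of the moved logic is expandStarExpression: a `*` nested inside a supported expression (UnresolvedFunction, CreateStruct, CreateArray, Murmur3Hash) is replaced by one attribute per column of the child's output, so hash(*) becomes hash(a, b). A minimal, self-contained sketch of that idea, using toy expression classes rather than Catalyst's types (Expr, Attribute, Star, and Func here are stand-ins):

object StarExpansionSketch extends App {
  sealed trait Expr
  case class Attribute(name: String) extends Expr
  case object Star extends Expr
  case class Func(name: String, children: Seq[Expr]) extends Expr

  // Replace a Star in a function's argument list with one attribute per
  // column of the child's output, recursing into nested calls.
  def expandStar(expr: Expr, childOutput: Seq[Attribute]): Expr = expr match {
    case Func(name, children) =>
      Func(name, children.flatMap {
        case Star => childOutput                         // '*' becomes every output column
        case other => Seq(expandStar(other, childOutput))
      })
    case other => other
  }

  // hash(*) over a relation with output (a, b) becomes hash(a, b).
  val expanded = expandStar(Func("hash", Seq(Star)), Seq(Attribute("a"), Attribute("b")))
  assert(expanded == Func("hash", Seq(Attribute("a"), Attribute("b"))))
}

In Catalyst the same flatMap-over-children shape is repeated once per supported node type, and any other expression still containing a star fails analysis; count(*) never reaches this code because it has already been rewritten to count(1).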
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4cfdcf95cb92..a7a948ef1b97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -306,21 +306,21 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 }
 
 /**
- * Attempts to eliminate the reading of unneeded columns from the query plan using the following
- * transformations:
+ * Attempts to eliminate the reading of unneeded columns from the query plan.
  *
- *  - Inserting Projections beneath the following operators:
- *   - Aggregate
- *   - Generate
- *   - Project <- Join
- *   - LeftSemiJoin
+ * Since adding Project before Filter conflicts with PushPredicateThroughProject, this rule will
+ * remove the Project p2 in the following pattern:
+ *
+ *   p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet)
+ *
+ * p2 is usually inserted by this rule and is useless; p1 can prune the columns anyway.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
     output1.size == output2.size &&
       output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -399,7 +399,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     } else {
       p
     }
-  }
+  })
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
@@ -408,6 +408,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
     } else {
       c
     }
+
+  /**
+   * The Project before Filter is not necessary but conflicts with PushPredicateThroughProject,
+   * so remove it.
+   */
+  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
+      if p2.outputSet.subsetOf(child.outputSet) =>
+      p1.copy(child = f.copy(child = child))
+  }
 }
 
 /**
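The removeProjectBeforeFilter cleanup exists because ColumnPruning and predicate pushdown otherwise fight each other at the fixed point: ColumnPruning inserts a Project beneath a Filter, PushPredicateThroughProject swaps the two back, and the batch oscillates until it hits maxIterations. Running the cleanup once over the transformed plan breaks the cycle. A self-contained sketch of the removed pattern with toy plan classes (Relation, Filter, Project here are stand-ins, not the Catalyst operators):

object RemoveProjectBeforeFilterSketch extends App {
  sealed trait Plan { def output: Set[String] }
  case class Relation(output: Set[String]) extends Plan
  case class Filter(condition: String, child: Plan) extends Plan {
    def output: Set[String] = child.output
  }
  case class Project(columns: Set[String], child: Plan) extends Plan {
    def output: Set[String] = columns
  }

  // A Project below a Filter that merely selects attributes the child already
  // produces computes nothing new; drop it, since the Project above the
  // Filter (p1) prunes the columns anyway.
  def removeProjectBeforeFilter(plan: Plan): Plan = plan match {
    case p1 @ Project(_, f @ Filter(_, Project(columns, child)))
        if columns.subsetOf(child.output) =>
      p1.copy(child = f.copy(child = child))
    case other => other
  }

  val input = Relation(Set("a", "b", "c"))
  val plan = Project(Set("a"), Filter("c > 0", Project(Set("a", "c"), input)))
  // Prints Project(Set(a), Filter(c > 0, Relation(Set(a, b, c)))).
  println(removeProjectBeforeFilter(plan))
}

The guard mirrors p2.outputSet.subsetOf(child.outputSet) in the real rule: the inner Project only passes through existing attributes, so removing it loses nothing.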
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 8e30349f50f0..6fc828f63f15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -22,8 +22,10 @@ import scala.collection.JavaConverters._
 import com.google.common.util.concurrent.AtomicLongMap
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.util.Utils
 
 object RuleExecutor {
   protected val timeMap = AtomicLongMap.create[String]()
@@ -98,7 +100,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
         if (iteration > batch.strategy.maxIterations) {
           // Only log if this is a rule that is supposed to run more than once.
           if (iteration != 2) {
-            logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+            val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
+            if (Utils.isTesting) {
+              throw new TreeNodeException(curPlan, message, null)
+            } else {
+              logWarning(message)
+            }
           }
           continue = false
         }
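With this change, a batch that fails to converge now throws in test runs instead of quietly logging, which is exactly how oscillations like the ColumnPruning one get caught. A compact model of the control flow, with IllegalStateException standing in for TreeNodeException and a plain boolean for Utils.isTesting (a sketch, not the real RuleExecutor):

object FixedPointSketch {
  // Apply `step` until the value stops changing or `maxIterations` is hit.
  // A non-converging batch should fail loudly in tests, and only warn in production.
  def fixedPoint[T](start: T, maxIterations: Int, isTesting: Boolean)(step: T => T): T = {
    var current = start
    var iteration = 0
    while (iteration < maxIterations) {
      val next = step(current)
      if (next == current) return current      // reached the fixed point
      current = next
      iteration += 1
    }
    val message = s"Max iterations ($maxIterations) reached"
    if (isTesting) throw new IllegalStateException(message)  // surface the bug
    println(s"WARN: $message")                               // tolerate it in production
    current
  }
}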
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 9563f43259fb..659cffded7ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -29,7 +29,7 @@ class AnalysisSuite extends AnalysisTest {
   import org.apache.spark.sql.catalyst.analysis.TestRelations._
 
   test("union project *") {
-    val plan = (1 to 100)
+    val plan = (1 to 120)
       .map(_ => testRelation)
       .fold[LogicalPlan](testRelation) { (a, b) =>
         a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index dd7d65ddc9e9..2248e03b2fc5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
+      PushPredicateThroughProject,
       ColumnPruning,
       CollapseProject) :: Nil
   }
@@ -133,12 +134,16 @@ class ColumnPruningSuite extends PlanTest {
 
   test("Column pruning on Filter") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val plan1 = Filter('a > 1, input).analyze
+    comparePlans(Optimize.execute(plan1), plan1)
     val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
-    val expected =
-      Project('a :: Nil,
-        Filter('c > Literal(0.0),
-          Project(Seq('a, 'c), input))).analyze
-    comparePlans(Optimize.execute(query), expected)
+    comparePlans(Optimize.execute(query), query)
+    val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze
+    val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze
+    comparePlans(Optimize.execute(plan2), expected2)
+    val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze
+    val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze
+    comparePlans(Optimize.execute(plan3), expected3)
   }
 
   test("Column pruning on except/intersect/distinct") {
@@ -297,7 +302,7 @@ class ColumnPruningSuite extends PlanTest {
           SortOrder('b, Ascending) :: Nil,
           UnspecifiedFrame)).as('window) :: Nil,
         'a :: Nil, 'b.asc :: Nil)
-      .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+      .where('window > 1).select('a, 'c).analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
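The reworked "Column pruning on Filter" test pins down the new contract: a bare Filter is left alone, a Project over a Filter no longer gets a pruning Project wedged in between, and a Filter sitting above a pass-through Project is pushed beneath it. One way to eyeball the same behavior end to end, assuming a local Spark 2.x dependency (the session setup and column names below are illustrative, not taken from the patch):

import org.apache.spark.sql.SparkSession

object PruningDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("pruning-demo").getOrCreate()
  import spark.implicits._

  val df = Seq((1, "x", 0.5), (2, "y", -0.5)).toDF("a", "b", "c")
  // Inspect the analyzed vs. optimized plans: after this patch no extra
  // Project should appear between the Project and the Filter.
  df.filter($"c" > 0.0).select("a").explain(true)

  spark.stop()
}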
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index a7de7b052bdc..c9d36910b099 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 
 class RuleExecutorSuite extends SparkFunSuite {
@@ -49,6 +51,9 @@ class RuleExecutorSuite extends SparkFunSuite {
       val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
     }
 
-    assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
+    val message = intercept[TreeNodeException[LogicalPlan]] {
+      ToFixedPoint.execute(Literal(100))
+    }.getMessage
+    assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
   }
 }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 05f59f15455e..5e6860d7e14c 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -338,6 +338,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_round_3",
     "view_cast",
 
+    // enable this after fixing SPARK-14137
+    "union20",
+
     // These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
     // generates different View Expanded Text.
     "alter_view_as_select",
@@ -1040,7 +1043,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "union18",
     "union19",
     "union2",
-    "union20",
     "union22",
     "union23",
     "union24",
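For completeness, the updated RuleExecutorSuite expectation can be mirrored against the FixedPointSketch helper shown earlier: decrementing an integer never converges, so ten iterations must trip the fail-fast path. This snippet assumes ScalaTest on the classpath and reuses the hypothetical fixedPoint sketch from above:

import org.scalatest.Assertions._
import FixedPointSketch.fixedPoint

// Mirrors the updated RuleExecutorSuite test: the driver gives up after 10
// iterations and, because isTesting is true, throws instead of warning.
val message = intercept[IllegalStateException] {
  fixedPoint(100, maxIterations = 10, isTesting = true)(_ - 1)
}.getMessage
assert(message.contains("Max iterations (10) reached"))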