From 6c89a15ed8eb868b23237bba07498fb2053f4643 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 30 Jan 2017 13:11:46 +0100 Subject: [PATCH 1/7] Open-up TreeNode's transform logic. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 42 ++++++----------- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 46 +++++-------------- 3 files changed, 25 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b108017c4c48..a5761703fd65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -242,31 +242,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { - var changed = false - - @inline def transformExpressionDown(e: Expression): Expression = { - val newE = e.transformDown(rule) - if (newE.fastEquals(e)) { - e - } else { - changed = true - newE - } - } - - def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpressionDown(e) - case Some(e: Expression) => Some(transformExpressionDown(e)) - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map(recursiveTransform) - case other: AnyRef => other - case null => null - } - - val newArgs = mapProductIterator(recursiveTransform) - - if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + mapExpressions(_.transformDown(rule)) } /** @@ -276,10 +252,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * @return */ def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUp(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new + * query operator based on the mapped expressions. + */ + def mapExpressions(f: Expression => Expression): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression): Expression = { - val newE = e.transformUp(rule) + @inline def transformExpression(e: Expression): Expression = { + val newE = f(e) if (newE.fastEquals(e)) { e } else { @@ -289,8 +273,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT } def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpressionUp(e) - case Some(e: Expression) => Some(transformExpressionUp(e)) + case e: Expression => transformExpression(e) + case Some(e: Expression) => Some(transformExpression(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 93550e1fc32a..0937825e273a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -56,7 +56,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { if (!analyzed) { - val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r)) + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[LogicalPlan]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 8fec9dd9b48b..f37661c31584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -190,26 +190,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arr } - /** - * Returns a copy of this node where `f` has been applied to all the nodes children. - */ - def mapChildren(f: BaseType => BaseType): BaseType = { - var changed = false - val newArgs = mapProductIterator { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (newChild fastEquals arg) { - arg - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - if (changed) makeCopy(newArgs) else this - } - /** * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. @@ -289,9 +269,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { - transformChildren(rule, (t, r) => t.transformDown(r)) + mapChildren(_.transformDown(rule)) } else { - afterRule.transformChildren(rule, (t, r) => t.transformDown(r)) + afterRule.mapChildren(_.transformDown(rule)) } } @@ -303,7 +283,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) + val afterRuleOnChildren = mapChildren(_.transformUp(rule)) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -316,18 +296,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } /** - * Returns a copy of this node where `rule` has been recursively applied to all the children of - * this node. When `rule` does not apply to a given node it is left unchanged. - * @param rule the function used to transform this nodes children + * Returns a copy of this node where `f` has been applied to all the nodes children. */ - protected def transformChildren( - rule: PartialFunction[BaseType, BaseType], - nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { + def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => - val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + val newChild = f(arg.asInstanceOf[BaseType]) if (!(newChild fastEquals arg)) { changed = true newChild @@ -335,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + val newChild = f(arg.asInstanceOf[BaseType]) if (!(newChild fastEquals arg)) { changed = true Some(newChild) @@ -344,7 +320,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } case m: Map[_, _] => m.mapValues { case arg: TreeNode[_] if containsChild(arg) => - val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + val newChild = f(arg.asInstanceOf[BaseType]) if (!(newChild fastEquals arg)) { changed = true newChild @@ -356,7 +332,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + val newChild = f(arg.asInstanceOf[BaseType]) if (!(newChild fastEquals arg)) { changed = true newChild @@ -364,8 +340,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule) - val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule) + val newChild1 = f(arg1.asInstanceOf[BaseType]) + val newChild2 = f(arg2.asInstanceOf[BaseType]) if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { changed = true (newChild1, newChild2) From dac7ec99075ce98ebea92e108ad66b05537de396 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 31 Jan 2017 17:03:57 +0100 Subject: [PATCH 2/7] Split RemoveAliasOnlyProject into RemoveRedundantAliases and RemoveRedundantProject. --- .../sql/catalyst/optimizer/Optimizer.scala | 138 +++++++++++++----- ...RemoveRedundantAliasAndProjectSuite.scala} | 29 +++- 2 files changed, 119 insertions(+), 48 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{RemoveAliasOnlyProjectSuite.scala => RemoveRedundantAliasAndProjectSuite.scala} (74%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 55d37cce9911..1eb37e43fdb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -110,7 +110,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCaseConversionExpressions, RewriteCorrelatedScalarSubquery, EliminateSerialization, - RemoveAliasOnlyProject) :: + RemoveRedundantAliases, + RemoveRedundantProject) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: Batch("Decimal Optimizations", fixedPoint, @@ -154,56 +155,113 @@ class SimpleTestOptimizer extends Optimizer( new SimpleCatalystConf(caseSensitiveAnalysis = true)) /** - * Removes the Project only conducting Alias of its child node. - * It is created mainly for removing extra Project added in EliminateSerialization rule, - * but can also benefit other operators. + * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change + * the name or metadata of a column, and does not deduplicate it. */ -object RemoveAliasOnlyProject extends Rule[LogicalPlan] { +object RemoveRedundantAliases extends Rule[LogicalPlan] { + /** - * Returns true if the project list is semantically same as child output, after strip alias on - * attribute. + * Replace the attributes in an expression using the given mapping. */ - private def isAliasOnly( - projectList: Seq[NamedExpression], - childOutput: Seq[Attribute]): Boolean = { - if (projectList.length != childOutput.length) { - false - } else { - stripAliasOnAttribute(projectList).zip(childOutput).forall { - case (a: Attribute, o) if a semanticEquals o => true - case _ => false - } + private def createAttributeMapping(current: LogicalPlan, next: LogicalPlan) + : Seq[(Attribute, Attribute)] = { + current.output.zip(next.output).filterNot { + case (a1, a2) => a1.semanticEquals(a2) } } - private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = { - projectList.map { - // Alias with metadata can not be stripped, or the metadata will be lost. - // If the alias name is different from attribute name, we can't strip it either, or we may - // accidentally change the output schema name of the root plan. - case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name => - attr - case other => other - } + /** + * Create an attribute map from a sequence of Attribute to Attribute mappings. + */ + private def toAttributeMap(pairs: Seq[(Attribute, Attribute)]): AttributeMap[Attribute] = { + val map = AttributeMap(pairs) + assert(map.size == pairs.size) + map + } + + /** + * Remove the top-level alias from an expression when it is redundant. + */ + private def removeRedundantAlias(e: Expression, blacklist: AttributeSet): Expression = e match { + // Alias with metadata can not be stripped, or the metadata will be lost. + // If the alias name is different from attribute name, we can't strip it either, or we + // may accidentally change the output schema name of the root plan. + case a @ Alias(attr: Attribute, name) + if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) => + attr + case a => a } - def apply(plan: LogicalPlan): LogicalPlan = { - val aliasOnlyProject = plan.collectFirst { - case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p + /** + * Get an appropriate alias cleaning method for the given node. + * + * We currently clean Project, Aggregate & Window nodes. + */ + private def getAliasCleaner(plan: LogicalPlan): (Expression, AttributeSet) => Expression = { + plan match { + case _: Project => removeRedundantAlias + case _: Aggregate => removeRedundantAlias + case _: Window => removeRedundantAlias + case _ => (e, _) => e } + } - aliasOnlyProject.map { case proj => - val attributesToReplace = proj.output.zip(proj.child.output).filterNot { - case (a1, a2) => a1 semanticEquals a2 - } - val attrMap = AttributeMap(attributesToReplace) - plan transform { - case plan: Project if plan eq proj => plan.child - case plan => plan transformExpressions { - case a: Attribute if attrMap.contains(a) => attrMap(a) + /** + * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to + * prevent the removal of seemingly redundant aliases which are actually to deduplicate the + * input for a (self) join. + */ + private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { + plan match { + // A join has to be treated differently, because the left and the right side of the join are + // not allowed to use the same attributes. We use a blacklist to prevent us from creating a + // situation in which this happens; the rule will only remove an alias if its child is not on + // the black list. + case Join(left, right, joinType, condition) => + val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) + val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) + val mapping = toAttributeMap( + createAttributeMapping(left, newLeft) ++ + createAttributeMapping(right, newRight)) + val newCondition = condition.map(_.transform { + case a: Attribute => mapping.getOrElse(a, a) + }) + Join(newLeft, newRight, joinType, newCondition) + + case _ => + // Drop blacklisted attributes that are masked in the current project. This allows us to + // remove redundant aliases in the subtree. + val childBlacklist = blacklist -- (plan.inputSet -- plan.outputSet) + + // Remove redundant aliases in the subtree(s). + val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)] + val newNode = plan.mapChildren { child => + val newChild = removeRedundantAliases(child, childBlacklist) + currentNextAttrPairs ++= createAttributeMapping(child, newChild) + newChild } - } - }.getOrElse(plan) + + // Transform the expressions. + val cleanExpression = getAliasCleaner(plan) + val mapping = toAttributeMap(currentNextAttrPairs) + newNode.mapExpressions { expr => + val newExpr = expr.transform { + case a: Attribute => mapping.getOrElse(a, a) + } + cleanExpression(newExpr, blacklist) + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty) +} + +/** + * Remove projections from the query plan that do not make any modifications. + */ +object RemoveRedundantProject extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p @ Project(_, child) if p.output == child.output => child } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveAliasOnlyProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala similarity index 74% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveAliasOnlyProjectSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 7c26cb5598b3..73b34eda51ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveAliasOnlyProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -25,10 +25,14 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.MetadataBuilder -class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper { +class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("RemoveAliasOnlyProject", FixedPoint(50), RemoveAliasOnlyProject) :: Nil + val batches = Batch( + "RemoveAliasOnlyProject", + FixedPoint(50), + RemoveRedundantAliases, + RemoveRedundantProject) :: Nil } test("all expressions in project list are aliased child output") { @@ -40,8 +44,8 @@ class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper { test("all expressions in project list are aliased child output but with different order") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('b as 'b, 'a as 'a).analyze - val optimized = Optimize.execute(query) + val query = relation.select('b, 'a).analyze + val optimized = Optimize.execute(relation.select('b as 'b, 'a as 'a).analyze) comparePlans(optimized, query) } @@ -54,15 +58,15 @@ class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper { test("some expressions in project list are aliased child output but with different order") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('b as 'b, 'a).analyze - val optimized = Optimize.execute(query) + val query = relation.select('b, 'a).analyze + val optimized = Optimize.execute(relation.select('b as 'b, 'a).analyze) comparePlans(optimized, query) } test("some expressions in project list are not Alias or Attribute") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('a as 'a, 'b + 1).analyze - val optimized = Optimize.execute(query) + val query = relation.select('a, 'b + 1).analyze + val optimized = Optimize.execute(relation.select('a as 'a, 'b + 1).analyze) comparePlans(optimized, query) } @@ -74,4 +78,13 @@ class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper { val optimized = Optimize.execute(query) comparePlans(optimized, query) } + + test("do not dedup in cross join") { + val relation = LocalRelation('a.int) + val fragment = relation.select('a as 'a) + val query = relation.join(relation.select('a as 'a)).analyze + val optimized = Optimize.execute( + fragment.select('a as 'a).join(fragment.select('a as 'a)).analyze) + comparePlans(optimized, query) + } } From 6aad5d844b26f675cf13d3c898da785271ffcc7c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 31 Jan 2017 19:20:46 +0100 Subject: [PATCH 3/7] Fix union. --- .../sql/catalyst/optimizer/Optimizer.scala | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1eb37e43fdb9..4ed6b74090e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -170,15 +170,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { } } - /** - * Create an attribute map from a sequence of Attribute to Attribute mappings. - */ - private def toAttributeMap(pairs: Seq[(Attribute, Attribute)]): AttributeMap[Attribute] = { - val map = AttributeMap(pairs) - assert(map.size == pairs.size) - map - } - /** * Remove the top-level alias from an expression when it is redundant. */ @@ -208,19 +199,19 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to - * prevent the removal of seemingly redundant aliases which are actually to deduplicate the - * input for a (self) join. + * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) + * join. */ private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { plan match { // A join has to be treated differently, because the left and the right side of the join are // not allowed to use the same attributes. We use a blacklist to prevent us from creating a - // situation in which this happens; the rule will only remove an alias if its child is not on - // the black list. + // situation in which this happens; the rule will only remove an alias if its child + // attribute is not on the black list. case Join(left, right, joinType, condition) => val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) - val mapping = toAttributeMap( + val mapping = AttributeMap( createAttributeMapping(left, newLeft) ++ createAttributeMapping(right, newRight)) val newCondition = condition.map(_.transform { @@ -241,9 +232,13 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { newChild } + // Create the attribute mapping. Note that the currentNextAttrPairs can contain duplicate + // keys in case of Union (this is caused by the PushProjectionThroughUnion rule); in this + // case we use the the first mapping (which should be provided by the first child). + val mapping = AttributeMap(currentNextAttrPairs) + // Transform the expressions. val cleanExpression = getAliasCleaner(plan) - val mapping = toAttributeMap(currentNextAttrPairs) newNode.mapExpressions { expr => val newExpr = expr.transform { case a: Attribute => mapping.getOrElse(a, a) From 81f2fa5df7b56f96db396cdb7fd0a9b1d5e75c89 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 1 Feb 2017 15:48:30 +0100 Subject: [PATCH 4/7] Improve test coverage --- .../RemoveRedundantAliasAndProjectSuite.scala | 26 ++++++- .../test/resources/sql-tests/inputs/cte.sql | 15 ++++ .../test/resources/sql-tests/inputs/union.sql | 16 +++++ .../resources/sql-tests/results/cte.sql.out | 49 ++++++++++++- .../resources/sql-tests/results/union.sql.out | 70 ++++++++++++++++++- 5 files changed, 171 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 73b34eda51ef..2b55478f9bca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -31,6 +31,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper val batches = Batch( "RemoveAliasOnlyProject", FixedPoint(50), + PushProjectionThroughUnion, RemoveRedundantAliases, RemoveRedundantProject) :: Nil } @@ -79,7 +80,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper comparePlans(optimized, query) } - test("do not dedup in cross join") { + test("retain deduplicating alias in self-join") { val relation = LocalRelation('a.int) val fragment = relation.select('a as 'a) val query = relation.join(relation.select('a as 'a)).analyze @@ -87,4 +88,27 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper fragment.select('a as 'a).join(fragment.select('a as 'a)).analyze) comparePlans(optimized, query) } + + test("alias removal should not break after push project through union") { + val r1 = LocalRelation('a.int) + val r2 = LocalRelation('b.int) + val optimized = Optimize.execute( + r1.select('a as 'a).union(r2.select('b as 'b)).select('a).analyze) + val query = r1.union(r2) + comparePlans(optimized, query) + } + + test("remove redundant alias from aggregate") { + val relation = LocalRelation('a.int, 'b.int) + val optimized = Optimize.execute(relation.groupBy('a as 'a)('a as 'a, sum('b)).analyze) + val query = relation.groupBy('a)('a, sum('b)).analyze + comparePlans(optimized, query) + } + + test("remove redundant alias from window") { + val relation = LocalRelation('a.int, 'b.int) + val optimized = Optimize.execute(relation.window(Seq('b as 'b), Seq('a as 'a), Seq()).analyze) + val query = relation.window(Seq('b), Seq('a), Seq()).analyze + comparePlans(optimized, query) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte.sql b/sql/core/src/test/resources/sql-tests/inputs/cte.sql index 3914db26914b..d34d89f23575 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte.sql @@ -12,3 +12,18 @@ WITH s1 AS (SELECT 1 FROM s2), s2 AS (SELECT 1 FROM s1) SELECT * FROM s1, s2; -- WITH clause should reference the previous CTE WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2; + +-- SPARK-18609 CTE with self-join +WITH CTE1 AS ( + SELECT b.id AS id + FROM T2 a + CROSS JOIN (SELECT id AS id FROM T2) b +) +SELECT t1.id AS c1, + t2.id AS c2 +FROM CTE1 t1 + CROSS JOIN CTE1 t2; + +-- Clean up +DROP VIEW IF EXISTS t; +DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index 1f4780abde2d..e57d69eaad03 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -22,6 +22,22 @@ FROM (SELECT 0 a, 0 b SELECT SUM(1) a, CAST(0 AS BIGINT) b UNION ALL SELECT 0 a, 0 b) T; +-- Regression test for SPARK-18841 Push project through union should not be broken by redundant alias removal. +CREATE OR REPLACE TEMPORARY VIEW p1 AS VALUES 1 T(col); +CREATE OR REPLACE TEMPORARY VIEW p2 AS VALUES 1 T(col); +CREATE OR REPLACE TEMPORARY VIEW p3 AS VALUES 1 T(col); +SELECT 1 AS x, + col +FROM (SELECT col AS col + FROM (SELECT p1.col AS col + FROM p1 CROSS JOIN p2 + UNION ALL + SELECT col + FROM p3) T1) T2; + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; +DROP VIEW IF EXISTS p1; +DROP VIEW IF EXISTS p2; +DROP VIEW IF EXISTS p3; diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out index 9fbad8f3800a..a446c2cd183d 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query 0 @@ -55,3 +55,50 @@ struct 0 2 1 2 1 2 + + +-- !query 6 +WITH CTE1 AS ( + SELECT b.id AS id + FROM T2 a + CROSS JOIN (SELECT id AS id FROM T2) b +) +SELECT t1.id AS c1, + t2.id AS c2 +FROM CTE1 t1 + CROSS JOIN CTE1 t2 +-- !query 6 schema +struct +-- !query 6 output +0 0 +0 0 +0 0 +0 0 +0 1 +0 1 +0 1 +0 1 +1 0 +1 0 +1 0 +1 0 +1 1 +1 1 +1 1 +1 1 + + +-- !query 7 +DROP VIEW IF EXISTS t +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DROP VIEW IF EXISTS t2 +-- !query 8 schema +struct<> +-- !query 8 output + diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index c57028cabe93..d123b7fdbe0c 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 14 -- !query 0 @@ -65,7 +65,7 @@ struct -- !query 5 -DROP VIEW IF EXISTS t1 +CREATE OR REPLACE TEMPORARY VIEW p1 AS VALUES 1 T(col) -- !query 5 schema struct<> -- !query 5 output @@ -73,8 +73,72 @@ struct<> -- !query 6 -DROP VIEW IF EXISTS t2 +CREATE OR REPLACE TEMPORARY VIEW p2 AS VALUES 1 T(col) -- !query 6 schema struct<> -- !query 6 output + + +-- !query 7 +CREATE OR REPLACE TEMPORARY VIEW p3 AS VALUES 1 T(col) +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +SELECT 1 AS x, + col +FROM (SELECT col AS col + FROM (SELECT p1.col AS col + FROM p1 CROSS JOIN p2 + UNION ALL + SELECT col + FROM p3) T1) T2 +-- !query 8 schema +struct +-- !query 8 output +1 1 +1 1 + + +-- !query 9 +DROP VIEW IF EXISTS t1 +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +DROP VIEW IF EXISTS t2 +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +DROP VIEW IF EXISTS p1 +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +DROP VIEW IF EXISTS p2 +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +DROP VIEW IF EXISTS p3 +-- !query 13 schema +struct<> +-- !query 13 output + From acbb9e03a4cb36525920985e96983c3abd0f7326 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 3 Feb 2017 21:47:28 +0100 Subject: [PATCH 5/7] Code review --- .../sql/catalyst/optimizer/Optimizer.scala | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ed6b74090e8..4684e0206abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -183,20 +183,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { case a => a } - /** - * Get an appropriate alias cleaning method for the given node. - * - * We currently clean Project, Aggregate & Window nodes. - */ - private def getAliasCleaner(plan: LogicalPlan): (Expression, AttributeSet) => Expression = { - plan match { - case _: Project => removeRedundantAlias - case _: Aggregate => removeRedundantAlias - case _: Window => removeRedundantAlias - case _ => (e, _) => e - } - } - /** * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) @@ -220,14 +206,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { Join(newLeft, newRight, joinType, newCondition) case _ => - // Drop blacklisted attributes that are masked in the current project. This allows us to - // remove redundant aliases in the subtree. - val childBlacklist = blacklist -- (plan.inputSet -- plan.outputSet) - // Remove redundant aliases in the subtree(s). val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)] val newNode = plan.mapChildren { child => - val newChild = removeRedundantAliases(child, childBlacklist) + val newChild = removeRedundantAliases(child, blacklist) currentNextAttrPairs ++= createAttributeMapping(child, newChild) newChild } @@ -237,13 +219,20 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // case we use the the first mapping (which should be provided by the first child). val mapping = AttributeMap(currentNextAttrPairs) + // Create a an expression cleaning function for nodes that can actually produce redundant + // aliases, use identity otherwise. + val clean: Expression => Expression = plan match { + case _: Project => removeRedundantAlias(_, blacklist) + case _: Aggregate => removeRedundantAlias(_, blacklist) + case _: Window => removeRedundantAlias(_, blacklist) + case _ => identity[Expression] + } + // Transform the expressions. - val cleanExpression = getAliasCleaner(plan) newNode.mapExpressions { expr => - val newExpr = expr.transform { + clean(expr.transform { case a: Attribute => mapping.getOrElse(a, a) - } - cleanExpression(newExpr, blacklist) + }) } } } From 23743e1e52f7bb764eb5743abdd773b73166a7d7 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 7 Feb 2017 17:35:48 +0100 Subject: [PATCH 6/7] Update doc after CR --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c0cd81baf3a1..0c13e3e93a42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -164,7 +164,8 @@ class SimpleTestOptimizer extends Optimizer( object RemoveRedundantAliases extends Rule[LogicalPlan] { /** - * Replace the attributes in an expression using the given mapping. + * Create an attribute mapping from the old to the new attributes. This function will only + * return the attribute pairs that have changed. */ private def createAttributeMapping(current: LogicalPlan, next: LogicalPlan) : Seq[(Attribute, Attribute)] = { From 29c469643bdab299cafd4fbbd657d34601c0e690 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 7 Feb 2017 17:59:04 +0100 Subject: [PATCH 7/7] Unit test should have similar styles. --- .../RemoveRedundantAliasAndProjectSuite.scala | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 2b55478f9bca..c01ea01ec680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -45,9 +45,10 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper test("all expressions in project list are aliased child output but with different order") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('b, 'a).analyze - val optimized = Optimize.execute(relation.select('b as 'b, 'a as 'a).analyze) - comparePlans(optimized, query) + val query = relation.select('b as 'b, 'a as 'a).analyze + val optimized = Optimize.execute(query) + val expected = relation.select('b, 'a).analyze + comparePlans(optimized, expected) } test("some expressions in project list are aliased child output") { @@ -59,16 +60,18 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper test("some expressions in project list are aliased child output but with different order") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('b, 'a).analyze - val optimized = Optimize.execute(relation.select('b as 'b, 'a).analyze) - comparePlans(optimized, query) + val query = relation.select('b as 'b, 'a).analyze + val optimized = Optimize.execute(query) + val expected = relation.select('b, 'a).analyze + comparePlans(optimized, expected) } test("some expressions in project list are not Alias or Attribute") { val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('a, 'b + 1).analyze - val optimized = Optimize.execute(relation.select('a as 'a, 'b + 1).analyze) - comparePlans(optimized, query) + val query = relation.select('a as 'a, 'b + 1).analyze + val optimized = Optimize.execute(query) + val expected = relation.select('a, 'b + 1).analyze + comparePlans(optimized, expected) } test("some expressions in project list are aliased child output but with metadata") { @@ -83,32 +86,34 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper test("retain deduplicating alias in self-join") { val relation = LocalRelation('a.int) val fragment = relation.select('a as 'a) - val query = relation.join(relation.select('a as 'a)).analyze - val optimized = Optimize.execute( - fragment.select('a as 'a).join(fragment.select('a as 'a)).analyze) - comparePlans(optimized, query) + val query = fragment.select('a as 'a).join(fragment.select('a as 'a)).analyze + val optimized = Optimize.execute(query) + val expected = relation.join(relation.select('a as 'a)).analyze + comparePlans(optimized, expected) } test("alias removal should not break after push project through union") { val r1 = LocalRelation('a.int) val r2 = LocalRelation('b.int) - val optimized = Optimize.execute( - r1.select('a as 'a).union(r2.select('b as 'b)).select('a).analyze) - val query = r1.union(r2) - comparePlans(optimized, query) + val query = r1.select('a as 'a).union(r2.select('b as 'b)).select('a).analyze + val optimized = Optimize.execute(query) + val expected = r1.union(r2) + comparePlans(optimized, expected) } test("remove redundant alias from aggregate") { val relation = LocalRelation('a.int, 'b.int) - val optimized = Optimize.execute(relation.groupBy('a as 'a)('a as 'a, sum('b)).analyze) - val query = relation.groupBy('a)('a, sum('b)).analyze - comparePlans(optimized, query) + val query = relation.groupBy('a as 'a)('a as 'a, sum('b)).analyze + val optimized = Optimize.execute(query) + val expected = relation.groupBy('a)('a, sum('b)).analyze + comparePlans(optimized, expected) } test("remove redundant alias from window") { val relation = LocalRelation('a.int, 'b.int) - val optimized = Optimize.execute(relation.window(Seq('b as 'b), Seq('a as 'a), Seq()).analyze) - val query = relation.window(Seq('b), Seq('a), Seq()).analyze - comparePlans(optimized, query) + val query = relation.window(Seq('b as 'b), Seq('a as 'a), Seq()).analyze + val optimized = Optimize.execute(query) + val expected = relation.window(Seq('b), Seq('a), Seq()).analyze + comparePlans(optimized, expected) } }