From 84821b45c23dcd8ad4889de2f0b802918276aa91 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 28 Mar 2016 08:31:34 +0000
Subject: [PATCH 1/3] Fix Expand constraints.

---
 .../DistinctAggregationRewriter.scala         |  5 ++-
 .../sql/catalyst/optimizer/Optimizer.scala    |  7 ++--
 .../plans/logical/basicOperators.scala        | 35 ++++++++++++++++---
 .../optimizer/ColumnPruningSuite.scala        |  7 ++--
 .../plans/ConstraintPropagationSuite.scala    | 33 +++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala |  2 +-
 .../apache/spark/sql/hive/SQLBuilder.scala    |  2 +-
 7 files changed, 78 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 2e30d83a60970..3edf87f785954 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -222,10 +222,13 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] {
       }

       // Construct the expand operator.
+      val constraints =
+        Expand.constructValidConstraints(a.child.constraints, AttributeSet(groupByAttrs))
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
         groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
-        a.child)
+        a.child,
+        constraints)

       // Construct the first aggregate operator. This de-duplicates the all the children of
       // distinct operators, and applies the regular aggregate operators.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b97d..93220f7900cce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -327,14 +327,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+    case a @ Project(_, e @ Expand(_, _, grandChild, constraintsBase))
+        if (e.outputSet -- a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
         proj.zip(e.output).filter { case (e, a) =>
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, grandChild))
+      a.copy(child = Expand(newProjects, newOutput, grandChild, constraintsBase))

     // Prunes the unused columns from child of MapPartitions
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
@@ -343,7 +344,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
-    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+    case e @ Expand(_, _, child, _) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
       g.copy(child = prunedChild(g.child, g.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09c200fa839c7..347786e85d4c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -484,6 +484,7 @@ private[sql] object Expand {
       groupByAttrs: Seq[Attribute],
       gid: Attribute,
       child: LogicalPlan): Expand = {
+    var allNonSelectAttrSet = AttributeSet.empty
     // Create an array of Projections for the child projection, and replace the projections'
     // expressions which equal GroupBy expressions with Literal(null), if those expressions
     // are not set for this grouping set (according to the bit mask).
@@ -491,6 +492,8 @@ private[sql] object Expand {
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)

+      allNonSelectAttrSet = allNonSelectAttrSet ++ nonSelectedGroupAttrSet
+
       child.output ++ groupByAttrs.map { attr =>
         if (nonSelectedGroupAttrSet.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of for this group
@@ -502,8 +505,30 @@ private[sql] object Expand {
       // groupingId is the last output, here we use the bit mask as the concrete value for it.
       } :+ Literal.create(bitmask, IntegerType)
     }
-    val output = child.output ++ groupByAttrs :+ gid
-    Expand(projections, output, Project(child.output ++ groupByAliases, child))
+    val output = (child.output ++ groupByAttrs :+ gid).map { a =>
+      if (a.resolved && allNonSelectAttrSet.contains(a)) {
+        a.withNullability(true)
+      } else {
+        a
+      }
+    }
+    val expandChild = Project(child.output ++ groupByAliases, child)
+    val validConstraints = constructValidConstraints(expandChild.constraints, allNonSelectAttrSet)
+    Expand(projections, output, expandChild, validConstraints)
+  }
+
+  /**
+   * Filter out the `IsNotNull` constraints which cover the group by attributes in Expand operator.
+   * These constraints come from Expand's child plan. Because Expand will set group by attribute to
+   * null values in its projections, we need to filter out these `IsNotNull` constraints.
+   *
+   * @param constraints The constraints from Expand operator's child
+   * @param groupByAttrs The attributes of aliased group by expressions in Expand
+   */
+  def constructValidConstraints(
+      constraints: ExpressionSet,
+      groupByAttrs: AttributeSet): Seq[Expression] = {
+    constraints.filter(_.references.intersect(groupByAttrs).isEmpty).toSeq
   }
 }
@@ -518,8 +543,8 @@ private[sql] object Expand {
 case class Expand(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-
+    child: LogicalPlan,
+    constraintsBase: Seq[Expression]) extends UnaryNode {
   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))

@@ -527,6 +552,8 @@ case class Expand(
     val sizeInBytes = super.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
+
+  override protected def validConstraints: Set[Expression] = constraintsBase.toSet
 }

 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 2248e03b2fc58..27ae771ac9443 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -105,6 +105,7 @@ class ColumnPruningSuite extends PlanTest {
   test("Column pruning for Expand") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
+    val constraints = Expand.constructValidConstraints(input.constraints, AttributeSet('a))

     val query =
       Aggregate(
         Seq('aa, 'gid),
@@ -114,9 +115,9 @@ class ColumnPruningSuite extends PlanTest {
             Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
             Seq('a, 'b, 'c, 'a, 2)),
           Seq('a, 'b, 'c, 'aa.int, 'gid.int),
-          input)).analyze
+          input,
+          constraints)).analyze
     val optimized = Optimize.execute(query)
-
     val expected =
       Aggregate(
         Seq('aa, 'gid),
@@ -127,7 +128,7 @@ class ColumnPruningSuite extends PlanTest {
             Seq('c, 'a, 2)),
           Seq('c, 'aa.int, 'gid.int),
           Project(Seq('a, 'c),
-            input))).analyze
+            input), constraints)).analyze

     comparePlans(optimized, expected)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a353e..9337f65024b0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -88,6 +88,39 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
   }

+  test("propagating constraints in expand") {
+    val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+    assert(tr.analyze.constraints.isEmpty)
+
+    // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
+    // by creating notNullRelation.
+    val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
+    verifyConstraints(notNullRelation.analyze.constraints,
+      ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10,
+        IsNotNull(resolveColumn(notNullRelation.analyze, "c")),
+        resolveColumn(notNullRelation.analyze, "a") < 5,
+        IsNotNull(resolveColumn(notNullRelation.analyze, "a")),
+        resolveColumn(notNullRelation.analyze, "b") > 2,
+        IsNotNull(resolveColumn(notNullRelation.analyze, "b")))))
+
+    val constraints =
+      Expand.constructValidConstraints(
+        notNullRelation.analyze.constraints,
+        AttributeSet(resolveColumn(notNullRelation.analyze, "a").asInstanceOf[Attribute]))
+
+    val expand = Expand(
+      Seq(
+        Seq('c, Literal.create(null, StringType), 1),
+        Seq('c, 'a, 2)),
+      Seq('c, 'a, 'gid.int),
+      Project(Seq('a, 'c),
+        notNullRelation), constraints)
+    verifyConstraints(expand.analyze.constraints,
+      ExpressionSet(Seq(resolveColumn(expand.analyze, "c") > 10,
+        IsNotNull(resolveColumn(expand.analyze, "c")))))
+  }
+
   test("propagating constraints in aliases") {
     val tr = LocalRelation('a.int, 'b.string, 'c.int)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f93c2..cbf8a87b8c131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -400,7 +400,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Project(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, child) =>
+      case e @ logical.Expand(_, _, child, _) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
         execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e54358e657690..394572636b19e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -110,7 +110,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     case p: Project => projectToSQL(p, isDistinct = false)

-    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project, _)) if isGroupingSet(a, e, p) =>
       groupingSetToSQL(a, e, p)

     case p: Aggregate =>

From 23d6b37953e734d8e1ac1039518293aeb3c9138b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 30 Mar 2016 04:55:26 +0000
Subject: [PATCH 2/3] Move constraint pruning logic into Expand.

---
 .../DistinctAggregationRewriter.scala         |  4 +--
 .../sql/catalyst/optimizer/Optimizer.scala    |  4 +--
 .../plans/logical/basicOperators.scala        | 28 +++++++------------
 .../optimizer/ColumnPruningSuite.scala        |  5 ++--
 .../plans/ConstraintPropagationSuite.scala    |  7 +----
 5 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 3edf87f785954..f30aebe012528 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -222,13 +222,11 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] {
       }

       // Construct the expand operator.
-      val constraints =
-        Expand.constructValidConstraints(a.child.constraints, AttributeSet(groupByAttrs))
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
         groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
         a.child,
-        constraints)
+        groupByAttrs)

       // Construct the first aggregate operator. This de-duplicates the all the children of
       // distinct operators, and applies the regular aggregate operators.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93220f7900cce..a9113982df160 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -327,7 +327,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case a @ Project(_, e @ Expand(_, _, grandChild, constraintsBase))
+    case a @ Project(_, e @ Expand(_, _, grandChild, groupByAttrs))
         if (e.outputSet -- a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
@@ -335,7 +335,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, grandChild, constraintsBase))
+      a.copy(child = Expand(newProjects, newOutput, grandChild, groupByAttrs))

     // Prunes the unused columns from child of MapPartitions
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 347786e85d4c9..ee7c22668c169 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -513,22 +513,7 @@ private[sql] object Expand {
       }
     }
     val expandChild = Project(child.output ++ groupByAliases, child)
-    val validConstraints = constructValidConstraints(expandChild.constraints, allNonSelectAttrSet)
-    Expand(projections, output, expandChild, validConstraints)
-  }
-
-  /**
-   * Filter out the `IsNotNull` constraints which cover the group by attributes in Expand operator.
-   * These constraints come from Expand's child plan. Because Expand will set group by attribute to
-   * null values in its projections, we need to filter out these `IsNotNull` constraints.
-   *
-   * @param constraints The constraints from Expand operator's child
-   * @param groupByAttrs The attributes of aliased group by expressions in Expand
-   */
-  def constructValidConstraints(
-      constraints: ExpressionSet,
-      groupByAttrs: AttributeSet): Seq[Expression] = {
-    constraints.filter(_.references.intersect(groupByAttrs).isEmpty).toSeq
+    Expand(projections, output, expandChild, allNonSelectAttrSet.toSeq)
   }
 }
@@ -539,12 +524,13 @@ private[sql] object Expand {
  * @param projections to apply
  * @param output of all projections.
  * @param child operator.
+ * @param groupByAttrs the attributes used in group by.
  */
 case class Expand(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
     child: LogicalPlan,
-    constraintsBase: Seq[Expression]) extends UnaryNode {
+    groupByAttrs: Seq[Attribute]) extends UnaryNode {

   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
@@ -553,7 +539,13 @@ case class Expand(
     Statistics(sizeInBytes = sizeInBytes)
   }

-  override protected def validConstraints: Set[Expression] = constraintsBase.toSet
+  /**
+   * Filter out the `IsNotNull` constraints which cover the group by attributes in Expand operator.
+   * These constraints come from Expand's child plan. Because Expand will set group by attribute to
+   * null values in its projections, we need to filter out these `IsNotNull` constraints.
+   */
+  override protected def validConstraints: Set[Expression] =
+    child.constraints.filter(_.references.intersect(AttributeSet(groupByAttrs)).isEmpty)
 }

 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 27ae771ac9443..fa5b30b90ee2d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -105,7 +105,6 @@ class ColumnPruningSuite extends PlanTest {
   test("Column pruning for Expand") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
-    val constraints = Expand.constructValidConstraints(input.constraints, AttributeSet('a))

     val query =
       Aggregate(
         Seq('aa, 'gid),
@@ -116,7 +115,7 @@ class ColumnPruningSuite extends PlanTest {
             Seq('a, 'b, 'c, 'a, 2)),
           Seq('a, 'b, 'c, 'aa.int, 'gid.int),
           input,
-          constraints)).analyze
+          Seq('a))).analyze
     val optimized = Optimize.execute(query)
     val expected =
       Aggregate(
@@ -128,7 +127,7 @@ class ColumnPruningSuite extends PlanTest {
             Seq('c, 'a, 2)),
           Seq('c, 'aa.int, 'gid.int),
           Project(Seq('a, 'c),
-            input), constraints)).analyze
+            input), Seq('a))).analyze

     comparePlans(optimized, expected)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 9337f65024b0b..f3d129ce24d3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -104,18 +104,13 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         resolveColumn(notNullRelation.analyze, "b") > 2,
         IsNotNull(resolveColumn(notNullRelation.analyze, "b")))))

-    val constraints =
-      Expand.constructValidConstraints(
-        notNullRelation.analyze.constraints,
-        AttributeSet(resolveColumn(notNullRelation.analyze, "a").asInstanceOf[Attribute]))
-
     val expand = Expand(
       Seq(
         Seq('c, Literal.create(null, StringType), 1),
         Seq('c, 'a, 2)),
       Seq('c, 'a, 'gid.int),
       Project(Seq('a, 'c),
-        notNullRelation), constraints)
+        notNullRelation), Seq('a))
     verifyConstraints(expand.analyze.constraints,
       ExpressionSet(Seq(resolveColumn(expand.analyze, "c") > 10,
         IsNotNull(resolveColumn(expand.analyze, "c")))))

From ab89e620883f581b2104fc60ffb32f77501f94c4 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 31 Mar 2016 06:20:49 +0000
Subject: [PATCH 3/3] Address comment. Remove constraints for Expand operator.

---
 .../DistinctAggregationRewriter.scala         |  3 +--
 .../sql/catalyst/optimizer/Optimizer.scala    |  6 ++---
 .../plans/logical/basicOperators.scala        | 26 +++----------------
 .../optimizer/ColumnPruningSuite.scala        |  5 ++--
 .../plans/ConstraintPropagationSuite.scala    |  5 ++--
 .../spark/sql/execution/SparkStrategies.scala |  2 +-
 .../apache/spark/sql/hive/SQLBuilder.scala    |  2 +-
 7 files changed, 14 insertions(+), 35 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index f30aebe012528..2e30d83a60970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -225,8 +225,7 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] {
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
         groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
-        a.child,
-        groupByAttrs)
+        a.child)

       // Construct the first aggregate operator. This de-duplicates the all the children of
       // distinct operators, and applies the regular aggregate operators.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a9113982df160..73eacab31bb17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -327,7 +327,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case a @ Project(_, e @ Expand(_, _, grandChild, groupByAttrs))
+    case a @ Project(_, e @ Expand(_, _, grandChild))
         if (e.outputSet -- a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
@@ -335,7 +335,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, grandChild, groupByAttrs))
+      a.copy(child = Expand(newProjects, newOutput, grandChild))

     // Prunes the unused columns from child of MapPartitions
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
@@ -344,7 +344,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
-    case e @ Expand(_, _, child, _) if (child.outputSet -- e.references).nonEmpty =>
+    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
       g.copy(child = prunedChild(g.child, g.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ee7c22668c169..db8be9d59348a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -484,7 +484,6 @@ private[sql] object Expand {
       groupByAttrs: Seq[Attribute],
       gid: Attribute,
       child: LogicalPlan): Expand = {
-    var allNonSelectAttrSet = AttributeSet.empty
     // Create an array of Projections for the child projection, and replace the projections'
     // expressions which equal GroupBy expressions with Literal(null), if those expressions
     // are not set for this grouping set (according to the bit mask).
@@ -491,8 +491,6 @@ private[sql] object Expand {
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)

-      allNonSelectAttrSet = allNonSelectAttrSet ++ nonSelectedGroupAttrSet
-
       child.output ++ groupByAttrs.map { attr =>
         if (nonSelectedGroupAttrSet.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of for this group
@@ -502,15 +502,8 @@ private[sql] object Expand {
       // groupingId is the last output, here we use the bit mask as the concrete value for it.
       } :+ Literal.create(bitmask, IntegerType)
     }
-    val output = (child.output ++ groupByAttrs :+ gid).map { a =>
-      if (a.resolved && allNonSelectAttrSet.contains(a)) {
-        a.withNullability(true)
-      } else {
-        a
-      }
-    }
-    val expandChild = Project(child.output ++ groupByAliases, child)
-    Expand(projections, output, expandChild, allNonSelectAttrSet.toSeq)
+    val output = child.output ++ groupByAttrs :+ gid
+    Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }
@@ -524,13 +514,11 @@ private[sql] object Expand {
  * @param projections to apply
  * @param output of all projections.
  * @param child operator.
- * @param groupByAttrs the attributes used in group by.
  */
 case class Expand(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
-    child: LogicalPlan,
-    groupByAttrs: Seq[Attribute]) extends UnaryNode {
+    child: LogicalPlan) extends UnaryNode {

   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
@@ -539,13 +527,7 @@ case class Expand(
     Statistics(sizeInBytes = sizeInBytes)
   }

-  /**
-   * Filter out the `IsNotNull` constraints which cover the group by attributes in Expand operator.
-   * These constraints come from Expand's child plan. Because Expand will set group by attribute to
-   * null values in its projections, we need to filter out these `IsNotNull` constraints.
-   */
-  override protected def validConstraints: Set[Expression] =
-    child.constraints.filter(_.references.intersect(AttributeSet(groupByAttrs)).isEmpty)
+  override protected def validConstraints: Set[Expression] = Set.empty[Expression]
 }

 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index fa5b30b90ee2d..95b8d1b108210 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -114,8 +114,7 @@ class ColumnPruningSuite extends PlanTest {
             Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
             Seq('a, 'b, 'c, 'a, 2)),
           Seq('a, 'b, 'c, 'aa.int, 'gid.int),
-          input,
-          Seq('a))).analyze
+          input)).analyze
     val optimized = Optimize.execute(query)
     val expected =
       Aggregate(
@@ -127,7 +126,7 @@ class ColumnPruningSuite extends PlanTest {
             Seq('c, 'a, 2)),
           Seq('c, 'aa.int, 'gid.int),
           Project(Seq('a, 'c),
-            input), Seq('a))).analyze
+            input))).analyze

     comparePlans(optimized, expected)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index f3d129ce24d3f..cd1ab0e2b685a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -110,10 +110,9 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         Seq('c, 'a, 2)),
       Seq('c, 'a, 'gid.int),
       Project(Seq('a, 'c),
-        notNullRelation), Seq('a))
+        notNullRelation))
     verifyConstraints(expand.analyze.constraints,
-      ExpressionSet(Seq(resolveColumn(expand.analyze, "c") > 10,
-        IsNotNull(resolveColumn(expand.analyze, "c")))))
+      ExpressionSet(Seq.empty[Expression]))
   }

   test("propagating constraints in aliases") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cbf8a87b8c131..7841ff01f93c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -400,7 +400,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Project(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, child, _) =>
+      case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
         execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 394572636b19e..e54358e657690 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -110,7 +110,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     case p: Project => projectToSQL(p, isDistinct = false)

-    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project, _)) if isGroupingSet(a, e, p) =>
+    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
       groupingSetToSQL(a, e, p)

     case p: Aggregate =>
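
Background for the behavior the final patch settles on (Expand reports no validConstraints): the snippet below is an illustrative sketch only and is not part of the patches. It uses public APIs of this era (SQLContext, DataFrame.cube, Column.isNotNull); the object name ExpandNullabilityDemo and the local master setting are made up for the example. A CUBE/GROUPING SETS query is planned with an Expand node that emits one projection per grouping set and substitutes Literal(null) for the grouping columns not selected in that set, so an IsNotNull constraint inherited from the child would be wrong above Expand even though the child guarantees it.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ExpandNullabilityDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("expand-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // 'a' is filtered to be non-null, so the child plan carries IsNotNull(a).
    val df = sc.parallelize(Seq((1, "x"), (2, "y"))).toDF("a", "b").where($"a".isNotNull)

    // cube("a") is planned with an Expand node; the grand-total grouping set
    // replaces 'a' with null, so rows above Expand can contain a = null even
    // though no input row does. Propagating IsNotNull(a) here would be unsound.
    df.cube($"a").count().show()

    sc.stop()
  }
}

With the third patch applied, Expand contributes no constraints of its own, so expand.analyze.constraints is empty above such a plan; that conservative behavior is exactly what the updated ConstraintPropagationSuite test asserts.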