diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index beabacfc88e3..f6959e3c592d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -205,45 +205,30 @@ class Analyzer(
         GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
       case x: GroupingSets =>
         val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-        // We will insert another Projection if the GROUP BY keys contains the
-        // non-attribute expressions. And the top operators can references those
-        // expressions by its alias.
-        // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
-        //      SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
-
-        // find all of the non-attribute expressions in the GROUP BY keys
-        val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()
-
-        // The pair of (the original GROUP BY key, associated attribute)
-        val groupByExprPairs = x.groupByExprs.map(_ match {
-          case e: NamedExpression => (e, e.toAttribute)
-          case other => {
-            val alias = Alias(other, other.toString)()
-            nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
-            (other, alias.toAttribute)
-          }
-        })
-        // substitute the non-attribute expressions for aggregations.
-        val aggregation = x.aggregations.map(expr => expr.transformDown {
-          case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
-        }.asInstanceOf[NamedExpression])
+        val aliasedGroupByExprPairs = x.groupByExprs.map {
+          case a @ Alias(expr, _) => (expr, a)
+          case expr: NamedExpression => (expr, Alias(expr, expr.name)())
+          case expr => (expr, Alias(expr, expr.prettyString)())
+        }

-        // substitute the group by expressions.
-        val newGroupByExprs = groupByExprPairs.map(_._2)
+        val aliasedGroupByExprs = aliasedGroupByExprPairs.map(_._2)
+        val aliasedGroupByAttr = aliasedGroupByExprs.map(_.toAttribute)

-        val child = if (nonAttributeGroupByExpressions.length > 0) {
-          // insert additional projection if contains the
-          // non-attribute expressions in the GROUP BY keys
-          Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
-        } else {
-          x.child
+        // substitute the group by expressions in the aggregation list with the
+        // corresponding attributes produced by the aliased grouping expressions
+        val aggregations = x.aggregations.map {
+          case a @ Alias(e, name) =>
+            aliasedGroupByExprPairs.find(_._1.semanticEquals(e))
+              .map(_._2.toAttribute.withName(name)).getOrElse(a)
+          case e =>
+            aliasedGroupByExprPairs.find(_._1.semanticEquals(e)).map(_._2.toAttribute).getOrElse(e)
         }

         Aggregate(
-          newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
-          aggregation,
-          Expand(x.bitmasks, newGroupByExprs, gid, child))
+          aliasedGroupByAttr :+ gid, aggregations,
+          Generate(ExpandGroupingSets(aliasedGroupByExprs, x.bitmasks),
+            join = true, outer = false, qualifier = None, aliasedGroupByAttr :+ gid, x.child)
+        )
     }
 }
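Note on the new analysis step: every GROUP BY expression is paired with an alias, and references to those expressions in the aggregation list are rewritten to the alias's attribute. Here is a minimal standalone sketch of that substitution using toy expression classes; `Col`, `Add`, and `Aliased` are illustrative stand-ins for Catalyst's attribute, arithmetic, and `Alias` nodes, and structural equality stands in for `semanticEquals`:

```scala
// Toy model of the aliasing + substitution pass above (illustrative only;
// the real code operates on Catalyst expression trees).
sealed trait Expr
case class Col(name: String) extends Expr
case class Add(l: Expr, r: Expr) extends Expr
case class Aliased(child: Expr, name: String) extends Expr

object AliasSubstitutionSketch {
  // Stand-in for Catalyst's semanticEquals: plain structural equality here.
  def semanticEquals(a: Expr, b: Expr): Boolean = a == b

  def main(args: Array[String]): Unit = {
    val groupByExprs: Seq[Expr] = Seq(Add(Col("a"), Col("b")), Col("b"))

    // Pair every grouping expression with an alias, reusing existing names.
    val pairs = groupByExprs.map {
      case a @ Aliased(e, _) => (e, a)
      case c @ Col(n)        => (c, Aliased(c, n))
      case e                 => (e, Aliased(e, e.toString))
    }

    // Occurrences of a grouping expression in the aggregation list are
    // replaced by its alias, so the Aggregate only sees grouped-on columns.
    val aggregations: Seq[Expr] = Seq(Add(Col("a"), Col("b")), Col("c"))
    val substituted = aggregations.map { e =>
      pairs.find(p => semanticEquals(p._1, e)).map(_._2).getOrElse(e)
    }
    println(substituted)
    // List(Aliased(Add(Col(a),Col(b)),Add(Col(a),Col(b))), Col(c))
  }
}
```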
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 1a2092c909c5..e464ab0be70d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -146,3 +146,47 @@ case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
     }
   }
 }
+
+/**
+ * Given grouping expressions and a list of bitmasks corresponding to grouping sets, this produces
+ * one row per grouping set, with the unselected expressions set to null. Additionally, it outputs
+ * the bitmask as the last field.
+ */
+case class ExpandGroupingSets(
+    groupByExprs: Seq[Expression],
+    bitmasks: Seq[Int])
+  extends Expression with Generator with CodegenFallback {
+
+  override def elementTypes: Seq[(DataType, Boolean)] = {
+    groupByExprs.map(_.dataType).map(dt => (dt, true)) :+ (IntegerType, false)
+  }
+
+  override val children: Seq[Expression] = groupByExprs
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val groupByVals = groupByExprs.map(_.eval(input))
+
+    // Below is the imperative version of
+    // bitmasks.map { bitmask =>
+    //   val output = groupByVals.zipWithIndex.map { case (col, colNum) =>
+    //     if ((bitmask & (1 << colNum)) != 0) col else null
+    //   }
+    //   InternalRow(output :+ bitmask: _*)
+    // }
+
+    val outputRows = new Array[InternalRow](bitmasks.length)
+    var rowNum = 0
+    bitmasks.foreach { bitmask =>
+      val output: Array[Any] = new Array[Any](groupByVals.length + 1)
+      var colNum = 0
+      groupByVals.foreach { col =>
+        output(colNum) = if ((bitmask & (1 << colNum)) != 0) col else null
+        colNum += 1
+      }
+      output(groupByVals.length) = bitmask
+      outputRows(rowNum) = InternalRow(output: _*)
+      rowNum += 1
+    }
+    outputRows
+  }
+}
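For intuition about what the new plan computes: `ExpandGroupingSets` emits one row per bitmask for each input row, and `Generate` with `join = true` appends each generated row to the input row that produced it. A plain-collections sketch of that pipeline (the names and the ROLLUP mask values here are illustrative; the real operator works on `InternalRow`s):

```scala
object GenerateSketch {
  // For bit i of the mask, a set bit keeps grouping value i; a cleared bit
  // nulls it out. The mask itself is appended as the grouping id.
  def expand(vals: Seq[Any], masks: Seq[Int]): Seq[Seq[Any]] =
    masks.map { mask =>
      vals.zipWithIndex.map { case (v, i) =>
        if ((mask & (1 << i)) != 0) v else null
      } :+ mask
    }

  def main(args: Array[String]): Unit = {
    val input = Seq(Seq[Any](1, 2), Seq[Any](3, 4)) // rows of (a, b)
    val masks = Seq(3, 1, 0)                        // illustrative ROLLUP(a, b) masks
    // join = true: each generated row is appended to the row that produced it.
    val out = input.flatMap(row => expand(row, masks).map(row ++ _))
    out.foreach(println)
    // List(1, 2, 1, 2, 3), List(1, 2, 1, null, 1), List(1, 2, null, null, 0), ...
  }
}
```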
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d37f43888fd4..67dc1aa96158 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,10 +200,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
-
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = Project(a.references.toSeq, child))
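The dedicated Aggregate-over-Expand pruning case can be dropped because the logical `Expand` node no longer exists; the generic Aggregate case that remains already narrows the child's output to the attributes the aggregate references. A toy illustration of that rule with plain sets (hypothetical names, not Catalyst API):

```scala
// Sketch of the surviving generic rule: project the child onto exactly the
// attributes the aggregate references, if the child produces anything extra.
object ColumnPruningSketch {
  def pruneForAggregate(childOutput: Set[String], aggReferences: Set[String]): Set[String] =
    if ((childOutput -- aggReferences).nonEmpty) childOutput.intersect(aggReferences)
    else childOutput

  def main(args: Array[String]): Unit = {
    // The child produces a, b, c but the aggregate only touches a and b.
    println(pruneForAggregate(Set("a", "b", "c"), Set("a", "b"))) // Set(a, b)
  }
}
```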
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf33e..4a0a82f47940 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,77 +235,6 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }

-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
-  /**
-   * Extract attribute set according to the grouping id
-   * @param bitmask bitmask to represent the selected of the attribute sequence
-   * @param exprs the attributes in sequence
-   * @return the attributes of non selected specified via bitmask (with the bit set to 1)
-   */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-  : OpenHashSet[Expression] = {
-    val set = new OpenHashSet[Expression](2)
-
-    var bit = exprs.length - 1
-    while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
-      bit -= 1
-    }
-
-    set
-  }
-
-  /**
-   * Create an array of Projections for the child projection, and replace the projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those expressions
-   * are not set for this grouping set (according to the bit mask).
-   */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
-      // get the non selected grouping attributes according to the bit mask
-      val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
-
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
-        case x: Expression if nonSelectedGroupExprSet.contains(x) =>
-          // if the input attribute in the Invalid Grouping Expression set of for this group
-          // replace it with constant null
-          Literal.create(null, expr.dataType)
-        case x if x == gid =>
-          // replace the groupingId with concrete value (the bit mask)
-          Literal.create(bitmask, IntegerType)
-      })
-
-      result += substitution
-    }
-
-    result.toSeq
-  }
-
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
-  }
-}
-
 trait GroupingAnalytics extends UnaryNode {

   def groupByExprs: Seq[Expression]
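For reference, the bitmask convention this patch relies on: bit i set means grouping expression i is selected in that grouping set, so ROLLUP corresponds to prefix masks and CUBE to every subset. A sketch of that mapping (helper names and enumeration order are assumptions, not the actual `bitmasks` helper in Analyzer):

```scala
object GroupingMaskSketch {
  // ROLLUP(e0, ..., e(n-1)) keeps prefixes: n columns, n-1 columns, ..., none.
  def rollupMasks(n: Int): Seq[Int] = (n to 0 by -1).map(k => (1 << k) - 1)
  // CUBE keeps every subset of the grouping expressions.
  def cubeMasks(n: Int): Seq[Int] = ((1 << n) - 1) to 0 by -1

  def main(args: Array[String]): Unit = {
    println(rollupMasks(2)) // Vector(3, 1, 0): (a, b), (a), ()
    println(cubeMasks(2))   // 3, 2, 1, 0: every subset of (a, b)
  }
}
```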
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
deleted file mode 100644
index a458881f4094..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
-
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param projections The group of expressions, all of the group expressions should
- *                    output the same schema specified bye the parameter `output`
- * @param output The output Schema
- * @param child Child operator
- */
-case class Expand(
-    projections: Seq[Seq[Expression]],
-    output: Seq[Attribute],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  // The GroupExpressions can output data with arbitrary partitioning, so set it
-  // as UNKNOWN partitioning
-  override def outputPartitioning: Partitioning = UnknownPartitioning(0)
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    child.execute().mapPartitions { iter =>
-      // TODO Move out projection objects creation and transfer to
-      // workers via closure. However we can't assume the Projection
-      // is serializable because of the code gen, so we have to
-      // create the projections within each of the partition processing.
-      val groups = projections.map(ee => newProjection(ee, child.output)).toArray
-
-      new Iterator[InternalRow] {
-        private[this] var result: InternalRow = _
-        private[this] var idx = -1  // -1 means the initial state
-        private[this] var input: InternalRow = _
-
-        override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext
-
-        override final def next(): InternalRow = {
-          if (idx <= 0) {
-            // in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple
-            input = iter.next()
-            idx = 0
-          }
-
-          result = groups(idx)(input)
-          idx += 1
-
-          if (idx == groups.length && iter.hasNext) {
-            idx = 0
-          }
-
-          result
-        }
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 32067266b516..e09551ac6737 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -421,8 +421,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, _, child) =>
-        execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
         val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
         if (useNewAggregation && a.newAggregation.isDefined) {
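The core of the deleted physical operator was an iterator that applies every projection to the current input row before advancing to the next one. Reduced to plain Scala for clarity (generic types stand in for `InternalRow` and `Projection`; the control flow mirrors the deleted code):

```scala
object ExpandIteratorSketch {
  // Yields groups(0)(row), ..., groups(n-1)(row) for each row of iter.
  def expandIterator[A, B](iter: Iterator[A], groups: Array[A => B]): Iterator[B] =
    new Iterator[B] {
      private var idx = -1 // -1 marks the initial state
      private var input: A = _

      def hasNext: Boolean = (idx > -1 && idx < groups.length) || iter.hasNext

      def next(): B = {
        if (idx <= 0) {
          // In the initial (-1) or beginning (0) state, fetch the next input row.
          input = iter.next()
          idx = 0
        }
        val result = groups(idx)(input)
        idx += 1
        // All projections applied: rewind for the next input row, if any.
        if (idx == groups.length && iter.hasNext) idx = 0
        result
      }
    }

  def main(args: Array[String]): Unit = {
    val projections: Array[Int => String] = Array(i => s"a=$i", i => s"b=${i * 2}")
    println(expandIterator(Iterator(1, 2), projections).toList)
    // List(a=1, b=2, a=2, b=4)
  }
}
```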
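The expected answers in the rollup test below can be sanity-checked by hand over testData2 = (1,1), (1,2), (2,1), (2,2), (3,1), (3,2); for example, the grand-total row for sum("b") is Row(null, null, 9). A quick plain-Scala check (illustrative only, not part of this patch):

```scala
object RollupAnswerCheck {
  def main(args: Array[String]): Unit = {
    val data = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    // ROLLUP(a, b) aggregates at levels (a, b), (a), and ().
    val byAB = data.groupBy(identity).map { case ((a, b), rows) => (a, b, rows.map(_._2).sum) }
    val byA = data.groupBy(_._1).map { case (a, rows) => (a, null, rows.map(_._2).sum) }
    val total = Seq((null, null, data.map(_._2).sum))
    (byAB.toSeq ++ byA.toSeq ++ total).foreach(println)
    // Each (a, b) pair occurs once, each (a) level sums to 3,
    // and the grand total is 9, matching Row(null, null, 9) below.
  }
}
```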
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f5ef9ffd7f4f..60dc59f974e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -60,6 +60,40 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }

+  test("rollup") {
+    checkAnswer(
+      testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.rollup("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, 1, 3) :: Row(null, 2, 0)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.cube("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, 1, 3) :: Row(null, 2, 6)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
   test("spark.sql.retainGroupColumns config") {
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),