@@ -205,45 +205,30 @@ class Analyzer(
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
// We will insert another Projection if the GROUP BY keys contain
// non-attribute expressions, so that the operators above can reference
// those expressions by their aliases.
// e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
// SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a

// find all of the non-attribute expressions in the GROUP BY keys
val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()

// The pair of (the original GROUP BY key, associated attribute)
val groupByExprPairs = x.groupByExprs.map(_ match {
case e: NamedExpression => (e, e.toAttribute)
case other => {
val alias = Alias(other, other.toString)()
nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
(other, alias.toAttribute)
}
})

// substitute the non-attribute expressions for aggregations.
val aggregation = x.aggregations.map(expr => expr.transformDown {
case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
}.asInstanceOf[NamedExpression])
val aliasedGroupByExprPairs = x.groupByExprs.map{
case a @ Alias(expr, _) => (expr, a)
case expr: NamedExpression => (expr, Alias(expr, expr.name)())
Contributor:

You could also use expr.toAttribute for a NamedExpression instead of creating an Alias.

Contributor Author:

I believe I need a new Alias here since we really have two versions of the expression -- the original and the version manipulated by the Generator with nulls inserted per the bitmask. In the Aggregate 'aggregation' list the grouping columns need to refer to the manipulated version and 'real' aggregates need to refer to the original version.

case expr => (expr, Alias(expr, expr.prettyString)())
}

// substitute the group by expressions.
val newGroupByExprs = groupByExprPairs.map(_._2)
val aliasedGroupByExprs = aliasedGroupByExprPairs.map(_._2)
val aliasedGroupByAttr = aliasedGroupByExprs.map(_.toAttribute)

val child = if (nonAttributeGroupByExpressions.length > 0) {
// insert additional projection if contains the
// non-attribute expressions in the GROUP BY keys
Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
} else {
x.child
// substitute group by expressions in aggregation list with appropriate attribute
val aggregations = x.aggregations.map{
Contributor:

expr => expr.transformDown { 
..
}

Otherwise it's not able to substitute the expression like sum(a+b) + count(c) for a+b.

Contributor Author:

@chenghao-intel actually that change would bring back the bug in question since it would do the substitutions in situations like below and the aggregations would be computed off the manipulated (nulls inserted) values.

select a + b, c, sum(a+b) + count(c)
from t1
group by a + b, c with rollup

In general anything below an AggregateExpression we don't want to transform, but above we do. So really I need a transformDownUntil method. BTW making this change does fix the groupby_grouping_sets1 test so I really do need to do something.
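
A minimal sketch of what such a transformDownUntil-style substitution might look like — purely illustrative, not part of this patch; it reuses aliasedGroupByExprPairs from the surrounding code and assumes Catalyst's AggregateExpression trait and TreeNode.withNewChildren:

def substituteAboveAggregates(e: Expression): Expression = e match {
  // Stop descending at aggregate functions so their arguments keep
  // referring to the original (un-nulled) child columns.
  case agg: AggregateExpression => agg
  case other =>
    aliasedGroupByExprPairs.find(_._1.semanticEquals(other)) match {
      // Expressions above any aggregate are rewritten to the aliased,
      // null-manipulated grouping attribute produced by the generator.
      case Some((_, alias)) => alias.toAttribute
      case None => other.withNewChildren(other.children.map(substituteAboveAggregates))
    }
}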

case a @ Alias(e, name) =>
aliasedGroupByExprPairs.find(_._1.semanticEquals(e))
.map(_._2.toAttribute.withName(name)).getOrElse(a)
case e =>
aliasedGroupByExprPairs.find(_._1.semanticEquals(e)).map(_._2.toAttribute).getOrElse(e)
}

Aggregate(
newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
aggregation,
Expand(x.bitmasks, newGroupByExprs, gid, child))
aliasedGroupByAttr :+ gid, aggregations,
Generate(ExpandGroupingSets(aliasedGroupByExprs, x.bitmasks),
join = true, outer = false, qualifier = None, aliasedGroupByAttr :+ gid, x.child)
)
}
}
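
For orientation, a rough sketch of the rewritten plan shape for a query like SELECT a, b, sum(c) FROM t GROUP BY a, b WITH ROLLUP (illustrative only; the attribute names a' and b' are made up):

Aggregate(groupingExpressions = [a', b', grouping__id],
          aggregateExpressions = [a', b', sum(c)],
  Generate(ExpandGroupingSets([a', b'], bitmasks),
           join = true, outer = false, qualifier = None,
           generatorOutput = [a', b', grouping__id],
    <child plan for t>))

Because join = true, the generator output (the per-bitmask, null-manipulated grouping values plus the grouping id) is appended to the child's columns, so the grouping columns in the Aggregate read the generated attributes while aggregate functions such as sum(c) still read the original child columns.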

@@ -146,3 +146,47 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
}
}
}

/**
* Given grouping expressions and a list of bitmasks corresponding to grouping sets, this produces
* rows for each grouping set with the unselected expressions set to null. Additionally, it outputs
* the bitmask as the last field.
*/
case class ExpandGroupingSets(
groupByExprs: Seq[Expression],
bitmasks: Seq[Int])
extends Expression with Generator with CodegenFallback {

override def elementTypes: Seq[(DataType, Boolean)] = {
groupByExprs.map(_.dataType).map(dt => (dt, true)) :+ (IntegerType, false)
}

val children = groupByExprs

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val groupByVals = groupByExprs.map(_.eval(input))

// Below is the imperative version of
// bitmasks.map{ bitmask =>
// val output = groupByVals.zipWithIndex.map{ case (col, colNum) =>
// if ((bitmask & (1 << colNum)) != 0) col else null
// }
// InternalRow(output :+ bitmask: _*)
// }

val outputRows = new Array[InternalRow](bitmasks.length)
var rowNum = 0
bitmasks.foreach{ bitmask =>
val output: Array[Any] = new Array[Any](groupByVals.length + 1)
var colNum = 0
groupByVals.foreach{ col =>
output(colNum) = if ((bitmask & (1 << colNum)) != 0) col else null
colNum += 1
}
output(groupByVals.length) = bitmask
outputRows(rowNum) = InternalRow(output: _*)
rowNum += 1
}
outputRows
}
}
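
As a standalone illustration of the bitmask semantics implemented above — bit i selects the i-th grouping value and a cleared bit nulls it — here is a plain Scala rendering with made-up inputs (the bitmask values are only an example of what a two-column ROLLUP might supply):

val groupByVals: Seq[Any] = Seq("x", "y")   // evaluated grouping expressions
val bitmasks = Seq(3, 1, 0)                 // hypothetical grouping-set bitmasks
val rows = bitmasks.map { bitmask =>
  groupByVals.zipWithIndex.map { case (col, colNum) =>
    if ((bitmask & (1 << colNum)) != 0) col else null
  } :+ bitmask
}
// rows: List(List(x, y, 3), List(x, null, 1), List(null, null, 0))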
@@ -200,10 +200,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))

// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
@@ -235,77 +235,6 @@ case class Window(
projectList ++ windowExpressions.map(_.toAttribute)
}

/**
* Apply all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param child Child operator
*/
case class Expand(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
gid: Attribute,
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
val sizeInBytes = child.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}

val projections: Seq[Seq[Expression]] = expand()

/**
* Extract attribute set according to the grouping id
* @param bitmask bitmask indicating which attributes in the sequence are selected
* @param exprs the attributes in sequence
* @return the attributes not selected by the bitmask (those whose bit is cleared)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
bit -= 1
}

set
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)

val substitution = (child.output :+ gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute is in the non-selected grouping expression set for this
// group, replace it with constant null
Literal.create(null, expr.dataType)
case x if x == gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
}

override def output: Seq[Attribute] = {
child.output :+ gid
}
}

trait GroupingAnalytics extends UnaryNode {

def groupByExprs: Seq[Expression]

This file was deleted.

@@ -421,8 +421,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
if (useNewAggregation && a.newAggregation.isDefined) {
@@ -60,6 +60,40 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}

test("rollup") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, null, 3) :: Nil
)

checkAnswer(
testData2.rollup("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, null, 9) :: Nil
)
}

test("cube") {
checkAnswer(
testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, 1, 3) :: Row(null, 2, 0)
:: Row(null, null, 3) :: Nil
)

checkAnswer(
testData2.cube("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, 1, 3) :: Row(null, 2, 6)
:: Row(null, null, 9) :: Nil
)
}

test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),