Skip to content

Commit 4e53aab

Browse files
committed
Fix bug in regular aggregation path of the MultipleDistinctWriter: expressions and attributes didn't align.
1 parent 82b9c60 commit 4e53aab

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,13 +370,14 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
370370
// Setup expand for the 'regular' aggregate expressions.
371371
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
372372
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
373-
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
373+
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
374374

375375
// Setup aggregates for 'regular' aggregate expressions.
376376
val regularGroupId = Literal(0)
377+
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
377378
val regularAggOperatorMap = regularAggExprs.map { e =>
378379
// Perform the actual aggregation in the initial aggregate.
379-
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
380+
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
380381
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
381382

382383
// Select the result of the first aggregate in the last aggregate.
@@ -421,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
421422
// Construct the expand operator.
422423
val expand = Expand(
423424
regularAggProjection ++ distinctAggProjections,
424-
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
425+
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
425426
a.child)
426427

427428
// Construct the first aggregate operator. This de-duplicates the all the children of

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,32 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
532532
Row(3, 0) :: Nil)
533533
}
534534

535+
test("multiple distinct multiple columns sets") {
536+
checkAnswer(
537+
sqlContext.sql(
538+
"""
539+
|SELECT
540+
| key,
541+
| count(distinct value1),
542+
| sum(distinct value1),
543+
| count(distinct value2),
544+
| sum(distinct value2),
545+
| count(distinct value1, value2),
546+
| count(value1),
547+
| sum(value1),
548+
| count(value2),
549+
| sum(value2),
550+
| count(*),
551+
| count(1)
552+
|FROM agg2
553+
|GROUP BY key
554+
""".stripMargin),
555+
Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
556+
Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
557+
Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
558+
Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
559+
}
560+
535561
test("test count") {
536562
checkAnswer(
537563
sqlContext.sql(

0 commit comments

Comments
 (0)