Skip to content

Commit 8386b42

Browse files
committed
another fix
1 parent 6c4b295 commit 8386b42

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
115115
}
116116

117117
// Extract distinct aggregate expressions.
118-
val distincgAggExpressions = aggExpressions.filter(_.isDistinct)
119-
val distinctAggGroups = distincgAggExpressions.groupBy { e =>
118+
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
120119
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
121120
if (unfoldableChildren.nonEmpty) {
122121
// Only expand the unfoldable children
@@ -133,7 +132,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
133132
}
134133

135134
// Aggregation strategy can handle queries with a single distinct group.
136-
if (distincgAggExpressions.size > 1) {
135+
if (distinctAggGroups.size > 1) {
137136
// Create the attributes for the grouping id and the group by clause.
138137
val gid = AttributeReference("gid", IntegerType, nullable = false)()
139138
val groupByMap = a.groupingExpressions.collect {
@@ -152,7 +151,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
152151
}
153152

154153
// Setup unique distinct aggregate children.
155-
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
154+
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
156155
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
157156
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
158157

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
384384

385385
val (functionsWithDistinct, functionsWithoutDistinct) =
386386
aggregateExpressions.partition(_.isDistinct)
387-
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
387+
if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
388388
// This is a sanity check. We should not reach here when we have multiple distinct
389389
// column sets. Our `RewriteDistinctAggregates` should take care this case.
390390
sys.error("You hit a query analyzer bug. Please report your query to " +

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext {
6969
testPartialAggregationPlan(query)
7070
}
7171

72+
test("mixed aggregates with same distinct columns") {
73+
def assertNoExpand(plan: SparkPlan): Unit = {
74+
assert(plan.collect { case e: ExpandExec => e }.isEmpty)
75+
}
76+
77+
withTempView("v") {
78+
Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v")
79+
// one distinct column
80+
val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i")
81+
assertNoExpand(query1.queryExecution.executedPlan)
82+
83+
// 2 distinct columns
84+
val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i")
85+
assertNoExpand(query2.queryExecution.executedPlan)
86+
87+
// 2 distinct columns with different order
88+
val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
89+
assertNoExpand(query3.queryExecution.executedPlan)
90+
}
91+
}
92+
7293
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
7394
def checkPlan(fieldTypes: Seq[DataType]): Unit = {
7495
withTempView("testLimit") {

0 commit comments

Comments
 (0)