diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f2b7ca5cba25..377f084587ad3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -106,30 +106,23 @@ case class HashAggregateExec( // so return an empty iterator. Iterator.empty } else { - val aggregationIterator = - new TungstenAggregationIterator( - partIndex, - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - initialInputBufferOffset, - resultExpressions, - (expressions, inputSchema) => - MutableProjection.create(expressions, inputSchema), - inputAttributes, - iter, - testFallbackStartsAt, - numOutputRows, - peakMemory, - spillSize, - avgHashProbe, - numTasksFallBacked) - if (!hasInput && groupingExpressions.isEmpty) { - numOutputRows += 1 - Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) - } else { - aggregationIterator - } + new TungstenAggregationIterator( + partIndex, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + MutableProjection.create(expressions, inputSchema), + inputAttributes, + iter, + testFallbackStartsAt, + numOutputRows, + peakMemory, + spillSize, + avgHashProbe, + numTasksFallBacked) } aggTime += NANOSECONDS.toMillis(System.nanoTime() - beforeAgg) res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1ebf0d143bd1f..13510ee7b8aca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -184,8 +184,15 @@ class TungstenAggregationIterator( // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. val groupingKey = groupingProjection.apply(null) - val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + var buffer: UnsafeRow = if (aggregateExpressions.isEmpty) { + null + } else { + hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + } while (inputIter.hasNext) { + if (buffer == null) { + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + } val newInput = inputIter.next() processRow(buffer, newInput) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 620ee430cab20..a68cdb0ee8bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -823,11 +823,7 @@ class DataFrameAggregateSuite extends QueryTest "should produce correct aggregate") { _ => // explicit global aggregations val emptyAgg = Map.empty[String, String] - checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) checkAnswer(spark.emptyDataFrame.agg(count("*")), Seq(Row(0))) - checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(count("*")), Seq(Row(0))) // global aggregation is converted to grouping aggregation: @@ -2339,6 +2335,22 @@ class DataFrameAggregateSuite extends QueryTest test("SPARK-32761: aggregating multiple distinct CONSTANT columns") { checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1)) } + + test("aggregating single distinct column with empty and non-empty table") { + val tableName = "t" + withTable(tableName) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + sql(s"create table $tableName(col int not null) using parquet") + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(0)) + sql(s"insert into $tableName(col) values(1)") + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + sql(s"insert into $tableName(col) values(1)") + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + sql(s"insert into $tableName(col) values(2)") + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + } + } + } } case class B(c: Option[Double])