Skip to content

Commit 0008bd1

Browse files
uros-db authored and cloud-fan committed
[SPARK-49000][SQL][3.5] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates
### What changes were proposed in this pull request?

Fix the `RewriteDistinctAggregates` rule to deal properly with aggregation on DISTINCT literals. Physical plan for `select count(distinct 1) from t`:

```
-- count(distinct 1)
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[], functions=[count(distinct 1)], output=[count(DISTINCT 1)#2L])
   +- HashAggregate(keys=[], functions=[partial_count(distinct 1)], output=[count#6L])
      +- HashAggregate(keys=[], functions=[], output=[])
         +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=20]
            +- HashAggregate(keys=[], functions=[], output=[])
               +- FileScan parquet spark_catalog.default.t[] Batched: false, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/Users/nikola.mandic/oss-spark/spark-warehouse/org.apache.spark.s..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
```

The problem occurs when the `HashAggregate(keys=[], functions=[], output=[])` node yields one row to the `partial_count` node, which then captures that one row. This four-node structure is constructed by `AggUtils.planAggregateWithOneDistinct`.

To fix the problem, we add an `Expand` node, which forces non-empty grouping expressions in the `HashAggregateExec` nodes. This in turn allows zero rows to be streamed to the parent `partial_count` node, yielding the correct final result.

### Why are the changes needed?

Aggregation with a DISTINCT literal gives wrong results. For example, when running on an empty table `t`: `select count(distinct 1) from t` returns 1, while the correct result is 0. For reference: `select count(1) from t` returns 0, which is the correct and expected result.

### Does this PR introduce _any_ user-facing change?

Yes, this fixes a critical bug in Spark.

### How was this patch tested?

New e2e SQL tests for aggregates with DISTINCT literals.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47566 from uros-db/SPARK-49000-3.5.
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 4f9dbc3 commit 0008bd1

File tree

2 files changed

+124
-3
lines changed

2 files changed

+124
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,25 @@ import org.apache.spark.util.collection.Utils
197197
* techniques.
198198
*/
199199
object RewriteDistinctAggregates extends Rule[LogicalPlan] {
200+
/**
 * Returns true when the given distinct aggregates have to be rewritten by this rule even if
 * they form a single distinct group.
 */
private def mustRewrite(
201+
distinctAggs: Seq[AggregateExpression],
202+
groupingExpressions: Seq[Expression]): Boolean = {
203+
// If there are any distinct AggregateExpressions with filter, we need to rewrite the query.
204+
// Also, if there are no grouping expressions and at least one distinct aggregate expression
205+
// has only foldable children, we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1).
206+
// Without this case, such non-grouping aggregation queries will be incorrectly handled by
207+
// the aggregation strategy, causing wrong results when working with empty tables.
208+
distinctAggs.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&
209+
distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable)))
210+
}
200211

201212
private def mayNeedtoRewrite(a: Aggregate): Boolean = {
202213
val aggExpressions = collectAggregateExprs(a)
203214
val distinctAggs = aggExpressions.filter(_.isDistinct)
204215
// This rule applies when there are at least two distinct aggregates, or when mustRewrite
205216
// holds (e.g. a distinct aggregate has a filter clause); otherwise the aggregation strategy
206217
// can handle a single distinct aggregate group without filter clause.
207-
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
208-
distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
218+
distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
209219
}
210220

211221
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
@@ -236,7 +246,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
236246
}
237247

238248
// Aggregation strategy can handle queries with a single distinct group without filter clause.
239-
if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
249+
if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
240250
// Create the attributes for the grouping id and the group by clause.
241251
val gid = AttributeReference("gid", IntegerType, nullable = false)()
242252
val groupByMap = a.groupingExpressions.collect {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import scala.util.Random
2424
import org.scalatest.matchers.must.Matchers.the
2525

2626
import org.apache.spark.{SparkException, SparkThrowable}
27+
import org.apache.spark.sql.catalyst.plans.logical.Expand
2728
import org.apache.spark.sql.execution.WholeStageCodegenExec
2829
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
2930
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -2150,6 +2151,116 @@ class DataFrameAggregateSuite extends QueryTest
21502151
checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
21512152
}
21522153
}
2154+
2155+
test("aggregating with various distinct expressions") {
2156+
// Describes one query under test. `resultSeq(i)` is the expected answer when the table holds
// `i` rows (the driver loop below indexes it with 0, 1 and 2), and `hasExpandNodeInPlan` is
// whether the optimized plan is expected to contain an Expand node.
abstract class AggregateTestCaseBase(
2157+
val query: String,
2158+
val resultSeq: Seq[Seq[Row]],
2159+
val hasExpandNodeInPlan: Boolean)
2160+
// Test case whose expected results and Expand-node expectation are given explicitly.
case class AggregateTestCase(
2161+
override val query: String,
2162+
override val resultSeq: Seq[Seq[Row]],
2163+
override val hasExpandNodeInPlan: Boolean)
2164+
extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
2165+
// Test case with the default expectations: the query returns 0 for the first table state and
// 1 for the second and third, and the optimized plan contains an Expand node.
case class AggregateTestCaseDefault(
2166+
override val query: String)
2167+
extends AggregateTestCaseBase(
2168+
query,
2169+
Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
2170+
hasExpandNodeInPlan = true)
2171+
2172+
// Name of the table the queries below run against.
val t = "t"
2173+
// Queries under test, each paired with its expected results per table state and whether an
// Expand node is expected in the optimized plan.
val testCases: Seq[AggregateTestCaseBase] = Seq(
2174+
AggregateTestCaseDefault(
2175+
s"""SELECT COUNT(DISTINCT "col") FROM $t"""
2176+
),
2177+
AggregateTestCaseDefault(
2178+
s"SELECT COUNT(DISTINCT 1) FROM $t"
2179+
),
2180+
AggregateTestCaseDefault(
2181+
s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
2182+
),
2183+
AggregateTestCaseDefault(
2184+
s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
2185+
),
2186+
AggregateTestCase(
2187+
s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
2188+
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
2189+
hasExpandNodeInPlan = true
2190+
),
2191+
AggregateTestCaseDefault(
2192+
s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
2193+
),
2194+
AggregateTestCaseDefault(
2195+
s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
2196+
),
2197+
AggregateTestCaseDefault(
2198+
s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
2199+
),
2200+
AggregateTestCaseDefault(
2201+
s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
2202+
),
2203+
AggregateTestCaseDefault(
2204+
s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
2205+
),
2206+
// With a grouping expression present, no Expand node is expected in the plan.
AggregateTestCase(
2207+
s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
2208+
Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
2209+
hasExpandNodeInPlan = false
2210+
),
2211+
AggregateTestCaseDefault(
2212+
s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
2213+
),
2214+
// Expected to return 0 for every table state, with no Expand node in the plan.
AggregateTestCase(
2215+
s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
2216+
Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
2217+
hasExpandNodeInPlan = false
2218+
),
2219+
// Nested aggregates: the outer query is expected to return 1 for every table state.
AggregateTestCase(
2220+
s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
2221+
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
2222+
hasExpandNodeInPlan = false
2223+
),
2224+
AggregateTestCase(
2225+
s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
2226+
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
2227+
hasExpandNodeInPlan = false
2228+
),
2229+
AggregateTestCase(
2230+
s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
2231+
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
2232+
hasExpandNodeInPlan = false
2233+
),
2234+
AggregateTestCaseDefault(
2235+
s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
2236+
// Contrast between the next two cases: with the double-quoted "col" the second count is
// expected to stay at 1 with two rows, while with the unquoted col it is expected to be 2.
AggregateTestCase(
2237+
s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
2238+
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
2239+
hasExpandNodeInPlan = true
2240+
),
2241+
AggregateTestCase(
2242+
s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
2243+
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
2244+
hasExpandNodeInPlan = true
2245+
)
2246+
)
2247+
// Run every test case against the table while it holds zero, one and then two rows.
withTable(t) {
2248+
sql(s"create table $t(col int) using parquet")
2249+
// columnValue doubles as the current row count: nothing is inserted for 0, then the values
// 1 and 2 are inserted one per iteration.
Seq(0, 1, 2).foreach(columnValue => {
2250+
if (columnValue != 0) {
2251+
sql(s"insert into $t(col) values($columnValue)")
2252+
}
2253+
testCases.foreach(testCase => {
2254+
val query = sql(testCase.query)
2255+
// resultSeq is indexed by the current row count.
checkAnswer(query, testCase.resultSeq(columnValue))
2256+
// Check whether the optimized logical plan contains an Expand node.
val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
2257+
case _: Expand => true
2258+
}.nonEmpty
2259+
assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan)
2260+
})
2261+
})
2262+
}
2263+
}
21532264
}
21542265

21552266
case class B(c: Option[Double])

0 commit comments

Comments
 (0)