diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 1198d3fc53cb..ffe071ef25b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -42,8 +42,8 @@ object AggregateEstimation { (res, expr) => { val columnStat = childStats.attributeStats(expr.asInstanceOf[Attribute]) val distinctCount = columnStat.distinctCount.get - val distinctValue: BigInt = if (distinctCount == 0 && columnStat.nullCount.get > 0) { - 1 + val distinctValue: BigInt = if (columnStat.nullCount.get > 0) { + distinctCount + 1 } else { distinctCount } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index c24705043812..32bf20b8c17f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -40,7 +40,9 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), attr("key32") -> ColumnStat(distinctCount = Some(0), min = None, max = None, - nullCount = Some(4), avgLen = Some(4), maxLen = Some(4)) + nullCount = Some(4), avgLen = Some(4), maxLen = Some(4)), + attr("key33") -> ColumnStat(distinctCount = Some(2), min = None, max = None, + nullCount = Some(2), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -126,6 +128,15 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { expectedOutputRowCount = nameToColInfo("key22")._2.distinctCount.get) } + test("group-by column with null value") { + checkAggStats( + tableColumns = Seq("key21", "key33"), + tableRowCount = 6, + groupByColumns = Seq("key21", "key33"), + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + (nameToColInfo("key33")._2.distinctCount.get + 1)) + } + test("non-cbo estimation") { val attributes = Seq("key12").map(nameToAttr) val child = StatsTestPlan(