Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make the stats more accurate, yes, we can use a smaller number between childStats.rowCounts and limit as outputRowCount of getOutputSize

val sizeInBytes = if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
1
} else {
(limit: Long) * output.map(a => a.dataType.defaultSize).sum
}
child.stats(conf).copy(sizeInBytes = sizeInBytes)
val childStats = child.stats(conf)
val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
// Don't propagate column stats, because we don't know the distribution after a limit operation
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
rowCount = Some(rowCount),
isBroadcastable = childStats.isBroadcastable)
}
}

Expand All @@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = if (limit == 0) {
val childStats = child.stats(conf)
if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
1
Statistics(
sizeInBytes = 1,
rowCount = Some(0),
isBroadcastable = childStats.isBroadcastable)
} else {
(limit: Long) * output.map(a => a.dataType.defaultSize).sum
// The output row count of LocalLimit should be the sum of row counts from each partition.
// However, since the number of partitions is not available here, we just use statistics of
// the child. Because the distribution after a limit operation is unknown, we do not propagate
// the column stats.
childStats.copy(attributeStats = AttributeMap(Nil))
}
child.stats(conf).copy(sizeInBytes = sizeInBytes)
}
}

Expand Down Expand Up @@ -816,12 +822,14 @@ case class Sample(

override def computeStats(conf: CatalystConf): Statistics = {
val ratio = upperBound - lowerBound
// BigInt can't multiply with Double
var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100
val childStats = child.stats(conf)
var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
if (sizeInBytes == 0) {
sizeInBytes = 1
}
child.stats(conf).copy(sizeInBytes = sizeInBytes)
val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
// Don't propagate column stats, because we don't know the distribution after a sample operation
Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
}

override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class BasicStatsEstimationSuite extends StatsEstimationTestBase {
val attribute = attr("key")
val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)

val plan = StatsTestPlan(
outputList = Seq(attribute),
attributeStats = AttributeMap(Seq(attribute -> colStat)),
rowCount = 10,
// row count * (overhead + column size)
size = Some(10 * (8 + 4)))

test("limit estimation: limit < child's rowCount") {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
// LocalLimit's stats is just its child's stats except column stats
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
}

test("limit estimation: limit > child's rowCount") {
val localLimit = LocalLimit(Literal(20), plan)
val globalLimit = GlobalLimit(Literal(20), plan)
checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
// Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
}

test("limit estimation: limit = 0") {
val localLimit = LocalLimit(Literal(0), plan)
val globalLimit = GlobalLimit(Literal(0), plan)
val stats = Statistics(sizeInBytes = 1, rowCount = Some(0))
checkStats(localLimit, stats)
checkStats(globalLimit, stats)
}

test("sample estimation") {
val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)()
checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))

// Child doesn't have rowCount in stats
val childStats = Statistics(sizeInBytes = 120)
val childPlan = DummyLogicalPlan(childStats, childStats)
val sample2 =
Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)()
checkStats(sample2, Statistics(sizeInBytes = 14))
}

test("estimate statistics when the conf changes") {
val expectedDefaultStats =
Statistics(
sizeInBytes = 40,
rowCount = Some(10),
attributeStats = AttributeMap(Seq(
AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
isBroadcastable = false)
val expectedCboStats =
Statistics(
sizeInBytes = 4,
rowCount = Some(1),
attributeStats = AttributeMap(Seq(
AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
isBroadcastable = false)

val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
checkStats(
plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats)
}

/** Check estimated stats when cbo is turned on/off. */
private def checkStats(
plan: LogicalPlan,
expectedStatsCboOn: Statistics,
expectedStatsCboOff: Statistics): Unit = {
assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
// Invalidate statistics
plan.invalidateStatsCache()
assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
}

/** Check estimated stats when it's the same whether cbo is turned on or off. */
private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit =
checkStats(plan, expectedStats, expectedStats)
}

/**
* This class is used for unit-testing the cbo switch, it mimics a logical plan which computes
* a simple statistics or a cbo estimated statistics based on the conf.
*/
private case class DummyLogicalPlan(
defaultStats: Statistics,
cboStats: Statistics) extends LogicalPlan {
override def output: Seq[Attribute] = Nil
override def children: Seq[LogicalPlan] = Nil
override def computeStats(conf: CatalystConf): Statistics =
if (conf.cboEnabled) cboStats else defaultStats
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
spark.sessionState.conf.autoBroadcastJoinThreshold)
}

test("estimates the size of limit") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
val df = sql(s"""SELECT * FROM test limit $limit""")

val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
g.stats(conf).sizeInBytes
}
assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesGlobalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")

val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
l.stats(conf).sizeInBytes
}
assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesLocalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
}
}
}

test("column stats round trip serialization") {
// Make sure we serialize and then deserialize and we will get the result data
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
Expand Down