Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21984] [SQL] Join estimation based on equi-height histogram #19594

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
Expand Down Expand Up @@ -212,4 +213,172 @@ object EstimationUtils {
}
}

/**
 * Returns overlapped ranges between two histograms, in the given value range
 * [lowerBound, upperBound].
 *
 * Every bin that intersects the value range is first trimmed to that range by `trimBin`. Then
 * each pair of (left, right) trimmed bins whose ranges overlap produces one
 * [[OverlappedRange]], whose ndv and row counts are the trimmed bins' values scaled by the
 * fraction of each bin that lies inside the overlap.
 *
 * @param leftHistogram histogram of the left side.
 * @param rightHistogram histogram of the right side.
 * @param lowerBound lower bound of the value range.
 * @param upperBound upper bound of the value range.
 * @return the overlapped ranges, in the iteration order of the left bins and, for each left
 *         bin, of the right bins.
 */
def getOverlappedRanges(
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    lowerBound: Double,
    upperBound: Double): Seq[OverlappedRange] = {
  val overlappedRanges = new ArrayBuffer[OverlappedRange]()
  // Only bins whose range intersects [lowerBound, upperBound] have join possibility.
  // Trimming a bin depends only on that bin and the bounds, so every bin is trimmed exactly
  // once here instead of once per (left, right) pair inside the nested loops below.
  val leftBins = leftHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, leftHistogram.height, lowerBound, upperBound))
  val rightBins = rightHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, rightHistogram.height, lowerBound, upperBound))

  leftBins.foreach { case (left, leftHeight) =>
    rightBins.foreach { case (right, rightHeight) =>
      // Only collect overlapped ranges.
      if (left.lo <= right.hi && left.hi >= right.lo) {
        val range = if (right.lo >= left.lo && right.hi >= left.hi) {
          // Case1: the left bin is "smaller" than the right bin
          //      left.lo            right.lo     left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          if (left.hi == right.lo) {
            // The overlapped range has only one value.
            OverlappedRange(
              lo = right.lo,
              hi = right.lo,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (left.hi - right.lo) / (left.hi - left.lo)
            val rightRatio = (left.hi - right.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = right.lo,
              hi = left.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo <= left.lo && right.hi <= left.hi) {
          // Case2: the left bin is "larger" than the right bin
          //      right.lo           left.lo      right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          if (right.hi == left.lo) {
            // The overlapped range has only one value.
            OverlappedRange(
              lo = right.hi,
              hi = right.hi,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (right.hi - left.lo) / (left.hi - left.lo)
            val rightRatio = (right.hi - left.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = left.lo,
              hi = right.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo >= left.lo && right.hi <= left.hi) {
          // Case3: the left bin contains the right bin
          //      left.lo            right.lo     right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          val leftRatio = (right.hi - right.lo) / (left.hi - left.lo)
          OverlappedRange(
            lo = right.lo,
            hi = right.hi,
            leftNdv = left.ndv * leftRatio,
            rightNdv = right.ndv,
            leftNumRows = leftHeight * leftRatio,
            rightNumRows = rightHeight
          )
        } else {
          assert(right.lo <= left.lo && right.hi >= left.hi)
          // Case4: the right bin contains the left bin
          //      right.lo           left.lo      left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          val rightRatio = (left.hi - left.lo) / (right.hi - right.lo)
          OverlappedRange(
            lo = left.lo,
            hi = left.hi,
            leftNdv = left.ndv,
            rightNdv = right.ndv * rightRatio,
            leftNumRows = leftHeight,
            rightNumRows = rightHeight * rightRatio
          )
        }
        overlappedRanges += range
      }
    }
  }
  overlappedRanges
}

/**
 * Restricts a histogram bin to the value range [lowerBound, upperBound] and scales its
 * statistics to the part that is kept.
 *
 * When the kept part has a non-zero width, both the ndv (rounded up) and the number of rows
 * are multiplied by the fraction of the original bin's width that is kept. When the kept part
 * collapses to a single value, it gets ndv = 1 and `height / bin.ndv` rows.
 *
 * A bin lying entirely outside the range matches none of the first three layouts below and
 * fails the assertion in the last one.
 *
 * @param bin the input histogram bin.
 * @param height the number of rows of the given histogram bin inside an equi-height histogram.
 * @param lowerBound lower bound of the given range.
 * @param upperBound upper bound of the given range.
 * @return trimmed part of the given bin and its number of rows.
 */
def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double)
  : (HistogramBin, Double) = {
  val (start, end) = bin match {
    //       bin.lo          lowerBound     upperBound      bin.hi
    // --------+------------------+------------+-------------+------->
    case b if b.lo <= lowerBound && b.hi >= upperBound =>
      (lowerBound, upperBound)

    //       bin.lo          lowerBound       bin.hi      upperBound
    // --------+------------------+------------+-------------+------->
    case b if b.lo <= lowerBound && b.hi >= lowerBound =>
      (lowerBound, b.hi)

    //    lowerBound            bin.lo     upperBound       bin.hi
    // --------+------------------+------------+-------------+------->
    case b if b.lo <= upperBound && b.hi >= upperBound =>
      (b.lo, upperBound)

    //    lowerBound            bin.lo        bin.hi     upperBound
    // --------+------------------+------------+-------------+------->
    case b =>
      assert(b.lo >= lowerBound && b.hi <= upperBound)
      (b.lo, b.hi)
  }

  if (start != end) {
    // A kept part of non-zero width can only come from a bin of non-zero width.
    assert(bin.hi != bin.lo)
    val keptFraction = (end - start) / (bin.hi - bin.lo)
    val trimmedNdv = math.ceil(bin.ndv * keptFraction).toLong
    (HistogramBin(start, end, trimmedNdv), height * keptFraction)
  } else {
    // Single-value result; a bin with bin.hi == bin.lo also ends up here.
    (HistogramBin(start, end, 1), height / bin.ndv)
  }
}

/**
 * One overlapped value range produced when joining two equi-height histograms.
 *
 * A join between two histograms may yield many of these: each one pairs a part of a single
 * bin of the left histogram with a part of a single bin of the right histogram.
 *
 * @param lo lower bound of this overlapped range.
 * @param hi higher bound of this overlapped range.
 * @param leftNdv ndv in the left part.
 * @param rightNdv ndv in the right part.
 * @param leftNumRows number of rows in the left part.
 * @param rightNumRows number of rows in the right part.
 */
case class OverlappedRange(
    lo: Double,
    hi: Double,
    leftNdv: Double,
    rightNdv: Double,
    leftNumRows: Double,
    rightNumRows: Double
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._


Expand Down Expand Up @@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging {
val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
if (ValueInterval.isIntersected(lInterval, rInterval)) {
val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match {
case (Some(l: Histogram), Some(r: Histogram)) =>
computeByHistogram(leftKey, rightKey, l, r, newMin, newMax)
case _ =>
computeByNdv(leftKey, rightKey, newMin, newMax)
}
keyStatsAfterJoin += (
// Histograms are propagated as unchanged. During future estimation, they should be
// truncated by the updated max/min. In this way, only pointers of the histograms are
// propagated and thus reduce memory consumption.
leftKey -> joinStat.copy(histogram = leftKeyStat.histogram),
rightKey -> joinStat.copy(histogram = rightKeyStat.histogram)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we do this inside computeByEquiHeightHistogram?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put it here because computeByEquiHeightHistogram returns a single stats, here we keep the histogram for leftKey and rightKey respectively.

)
// Return cardinality estimated from the most selective join keys.
if (card < joinCard) joinCard = card
} else {
Expand Down Expand Up @@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging {
(ceil(card), newStats)
}

/**
* Compute join cardinality using equi-height histograms.
*
* The two histograms are intersected within [newMin, newMax] via getOverlappedRanges, and the
* cardinality and ndv contributed by each overlapped range are summed up.
*
* @param leftKey join key of the left side; its column stats are looked up in leftStats.
* @param rightKey join key of the right side; its column stats are looked up in rightStats.
* @param leftHistogram histogram of the left join key.
* @param rightHistogram histogram of the right join key.
* @param newMin lower bound of the join keys' value range; it is unwrapped with `.get`, so it
*               must be defined.
* @param newMax upper bound of the join keys' value range; it is unwrapped with `.get`, so it
*               must be defined.
* @return the estimated join cardinality (rounded up by `ceil`) and the column stats built for
*         the join keys.
*/
private def computeByHistogram(
leftKey: AttributeReference,
rightKey: AttributeReference,
leftHistogram: Histogram,
rightHistogram: Histogram,
newMin: Option[Any],
newMax: Option[Any]): (BigInt, ColumnStat) = {
val overlappedRanges = getOverlappedRanges(
leftHistogram = leftHistogram,
rightHistogram = rightHistogram,
// Only numeric values have equi-height histograms, so the bounds are converted to Double
// through their string form.
// NOTE(review): `.get` throws if a bound is None - confirm callers always pass defined bounds.
lowerBound = newMin.get.toString.toDouble,
upperBound = newMax.get.toString.toDouble)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we assume the min/max must be defined here, I think the parameter type should be double instead of Option[Any]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's because we need to update the column stats' min and max at the end of the method.


// Running totals of the estimated cardinality and ndv over all overlapped ranges.
var card: BigDecimal = 0
var totalNdv: Double = 0
for (i <- overlappedRanges.indices) {
val range = overlappedRanges(i)
if (i == 0 || range.hi != overlappedRanges(i - 1).hi) {
// If range.hi == overlappedRanges(i - 1).hi, that means the current range has only one
// value, and this value is already counted in the previous range. So there is no need to
// count it in this range.
totalNdv += math.min(range.leftNdv, range.rightNdv)
}
// Apply the formula in this overlapped range:
// leftNumRows * rightNumRows / max(leftNdv, rightNdv).
card += range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv)
}

val leftKeyStat = leftStats.attributeStats(leftKey)
val rightKeyStat = rightStats.attributeStats(rightKey)
// The new key stats take the smaller maxLen of the two sides and the mean of their avgLen.
val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we count left/right numRows when calculating this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do we use left/right numRows to calculate this? Ideally avgLen is calculated by total length of keys / numRowsAfterJoin. For string type, we don't the exact length of the matched keys (we don't support string histogram yet), for numeric types, their avgLen should be the same. So the equation is a fair approximation.

// NOTE(review): the fourth argument (0) is presumably the null count - confirm against the
// ColumnStat definition.
val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen)
(ceil(card), newStats)
}

/**
* Propagate or update column stats for output attributes.
*/
Expand Down
Loading