-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21984] [SQL] Join estimation based on equi-height histogram #19594
Changes from 6 commits
8b2084a
6cb9b39
e69e213
ad14a5e
2a4ee99
2637429
e1669ed
16797d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
import scala.math.BigDecimal.RoundingMode | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} | ||
|
@@ -212,4 +213,186 @@ object EstimationUtils { | |
} | ||
} | ||
|
||
/**
 * Returns overlapped ranges between two histograms, in the given value range
 * [lowerBound, upperBound].
 *
 * Every bin that intersects [lowerBound, upperBound] is first trimmed to that range with
 * `trimBin`. Then, for every pair of one trimmed left bin and one trimmed right bin whose value
 * ranges intersect, one [[OverlappedRange]] is produced. The ndv and number of rows of each side
 * are scaled by the fraction of that side's (trimmed) bin width covered by the overlap.
 *
 * @param leftHistogram histogram of the left join key.
 * @param rightHistogram histogram of the right join key.
 * @param lowerBound lower bound (inclusive) of the value range to consider.
 * @param upperBound upper bound (inclusive) of the value range to consider.
 * @return the overlapped ranges, ordered by left bin and then by right bin.
 */
def getOverlappedRanges(
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    lowerBound: Double,
    upperBound: Double): Seq[OverlappedRange] = {
  val overlappedRanges = new ArrayBuffer[OverlappedRange]()
  // Only bins whose range intersects [lowerBound, upperBound] have join possibility.
  // Trimming a bin depends only on the bin itself and the bounds, so every bin is trimmed
  // exactly once here instead of once per (left bin, right bin) pair.
  val leftParts = leftHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, leftHistogram.height, lowerBound, upperBound))
  val rightParts = rightHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, rightHistogram.height, lowerBound, upperBound))

  leftParts.foreach { case (left, leftHeight) =>
    rightParts.foreach { case (right, rightHeight) =>
      // Only collect overlapped ranges.
      if (left.lo <= right.hi && left.hi >= right.lo) {
        // Cases 1 and 2 (single-value bins) are checked first: the ratio formulas in the later
        // cases divide by the width of a bin, which is zero for a single-value bin.
        val range = if (left.lo == left.hi) {
          // Case1: the left bin has only one value
          OverlappedRange(
            lo = left.lo,
            hi = left.lo,
            leftNdv = 1,
            rightNdv = 1,
            leftNumRows = leftHeight,
            rightNumRows = rightHeight / right.ndv
          )
        } else if (right.lo == right.hi) {
          // Case2: the right bin has only one value
          OverlappedRange(
            lo = right.lo,
            hi = right.lo,
            leftNdv = 1,
            rightNdv = 1,
            leftNumRows = leftHeight / left.ndv,
            rightNumRows = rightHeight
          )
        } else if (right.lo >= left.lo && right.hi >= left.hi) {
          // Case3: the left bin is "smaller" than the right bin
          //      left.lo            right.lo     left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          if (left.hi == right.lo) {
            // The overlapped range has only one value. This branch is needed: otherwise both
            // ratios below would be 0, which leads to a wrong result.
            OverlappedRange(
              lo = right.lo,
              hi = right.lo,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (left.hi - right.lo) / (left.hi - left.lo)
            val rightRatio = (left.hi - right.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = right.lo,
              hi = left.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo <= left.lo && right.hi <= left.hi) {
          // Case4: the left bin is "larger" than the right bin
          //      right.lo           left.lo      right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          if (right.hi == left.lo) {
            // The overlapped range has only one value (same reasoning as in Case3).
            OverlappedRange(
              lo = right.hi,
              hi = right.hi,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (right.hi - left.lo) / (left.hi - left.lo)
            val rightRatio = (right.hi - left.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = left.lo,
              hi = right.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo >= left.lo && right.hi <= left.hi) {
          // Case5: the left bin contains the right bin
          //      left.lo            right.lo     right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          val leftRatio = (right.hi - right.lo) / (left.hi - left.lo)
          OverlappedRange(
            lo = right.lo,
            hi = right.hi,
            leftNdv = left.ndv * leftRatio,
            rightNdv = right.ndv,
            leftNumRows = leftHeight * leftRatio,
            rightNumRows = rightHeight
          )
        } else {
          assert(right.lo <= left.lo && right.hi >= left.hi)
          // Case6: the right bin contains the left bin
          //      right.lo           left.lo      left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          val rightRatio = (left.hi - left.lo) / (right.hi - right.lo)
          OverlappedRange(
            lo = left.lo,
            hi = left.hi,
            leftNdv = left.ndv,
            rightNdv = right.ndv * rightRatio,
            leftNumRows = leftHeight,
            rightNumRows = rightHeight * rightRatio
          )
        }
        overlappedRanges += range
      }
    }
  }
  overlappedRanges
}
|
||
/**
 * Given an original bin and a value range [lowerBound, upperBound], returns the trimmed part
 * of the bin in that range and its number of rows.
 *
 * The bin must intersect [lowerBound, upperBound]; this is asserted.
 *
 * @param bin the original histogram bin.
 * @param height number of rows in the original bin.
 * @param lowerBound lower bound (inclusive) of the value range.
 * @param upperBound upper bound (inclusive) of the value range.
 * @return a pair of (the trimmed bin, the number of rows in the trimmed bin). The ndv and the
 *         number of rows are scaled by the fraction of the original bin's width that is kept.
 */
def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double)
  : (HistogramBin, Double) = {
  val (lo, hi) = if (bin.lo <= lowerBound && bin.hi >= upperBound) {
    //       bin.lo          lowerBound     upperBound      bin.hi
    // --------+------------------+------------+-------------+------->
    (lowerBound, upperBound)
  } else if (bin.lo <= lowerBound && bin.hi >= lowerBound) {
    //       bin.lo          lowerBound       bin.hi      upperBound
    // --------+------------------+------------+-------------+------->
    (lowerBound, bin.hi)
  } else if (bin.lo <= upperBound && bin.hi >= upperBound) {
    //    lowerBound            bin.lo     upperBound       bin.hi
    // --------+------------------+------------+-------------+------->
    (bin.lo, upperBound)
  } else {
    //    lowerBound            bin.lo        bin.hi      upperBound
    // --------+------------------+------------+-------------+------->
    // The three branches above cover every bin that crosses a bound, so the only valid case
    // left is a bin fully inside the range. A bin that does not intersect the range at all
    // would otherwise be returned untrimmed without any warning.
    assert(bin.lo >= lowerBound && bin.hi <= upperBound,
      s"bin [${bin.lo}, ${bin.hi}] does not intersect range [$lowerBound, $upperBound]")
    (bin.lo, bin.hi)
  }

  if (bin.hi == bin.lo) {
    // The original bin holds a single value: nothing to trim, keep the bin and all its rows.
    (bin, height)
  } else if (hi == lo) {
    // The trimmed part holds a single value: one distinct value with its share of the rows.
    (HistogramBin(lo, hi, 1), height / bin.ndv)
  } else {
    // Scale ndv and number of rows by the fraction of the bin's width that remains.
    val ratio = (hi - lo) / (bin.hi - bin.lo)
    (HistogramBin(lo, hi, math.ceil(bin.ndv * ratio).toLong), height * ratio)
  }
}
|
||
/**
 * A join between two equi-height histograms may produce multiple overlapped ranges.
 * Each overlapped range is produced by a part of one bin in the left histogram and a part of
 * one bin in the right histogram.
 *
 * The ndv and row-count fields are fractional (`Double`) because they are obtained by scaling
 * a bin's values by the fraction of the bin covered by the overlap.
 *
 * @param lo lower bound of this overlapped range.
 * @param hi higher bound of this overlapped range.
 * @param leftNdv ndv in the left part.
 * @param rightNdv ndv in the right part.
 * @param leftNumRows number of rows in the left part.
 * @param rightNumRows number of rows in the right part.
 */
case class OverlappedRange(
    lo: Double,
    hi: Double,
    leftNdv: Double,
    rightNdv: Double,
    leftNumRows: Double,
    rightNumRows: Double)
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging | |
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} | ||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys | ||
import org.apache.spark.sql.catalyst.plans._ | ||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} | ||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics} | ||
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ | ||
|
||
|
||
|
@@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging { | |
val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) | ||
if (ValueInterval.isIntersected(lInterval, rInterval)) { | ||
val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) | ||
val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) | ||
keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) | ||
val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match { | ||
case (Some(l: Histogram), Some(r: Histogram)) => | ||
computeByEquiHeightHistogram(leftKey, rightKey, l, r, newMin, newMax) | ||
case _ => | ||
computeByNdv(leftKey, rightKey, newMin, newMax) | ||
} | ||
keyStatsAfterJoin += ( | ||
// Histograms are propagated as unchanged. During future estimation, they should be | ||
// truncated by the updated max/min. In this way, only pointers of the histograms are | ||
// propagated and thus reduce memory consumption. | ||
leftKey -> joinStat.copy(histogram = leftKeyStat.histogram), | ||
rightKey -> joinStat.copy(histogram = rightKeyStat.histogram) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we update the histogram after join? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently we don't update histogram since min/max can help us to know which bins are valid. It doesn't affect correctness. But updating histograms helps to reduce memory usage for histogram propagation. We can do this in both filter and join estimation in following PRs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually keeping it unchanged is more memory efficient. We just pass around pointers, but updating the histogram means creating a new one. Let's keep it, and add some comments to explain it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah right, we can keep it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we do this inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I put it here because |
||
) | ||
// Return cardinality estimated from the most selective join keys. | ||
if (card < joinCard) joinCard = card | ||
} else { | ||
|
@@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging { | |
(ceil(card), newStats) | ||
} | ||
|
||
/**
 * Compute join cardinality using equi-height histograms.
 *
 * The two histograms are split into overlapped ranges within [newMin, newMax]; the cardinality
 * formula is applied to each overlapped range and the results are summed up.
 *
 * @param leftKey join key on the left side.
 * @param rightKey join key on the right side.
 * @param leftHistogram histogram of the left join key.
 * @param rightHistogram histogram of the right join key.
 * @param newMin lower bound of the intersected value interval; must be defined.
 * @param newMax upper bound of the intersected value interval; must be defined.
 * @return the estimated join cardinality and the column stat of the join keys after join.
 */
private def computeByEquiHeightHistogram(
    leftKey: AttributeReference,
    rightKey: AttributeReference,
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    newMin: Option[Any],
    newMax: Option[Any]): (BigInt, ColumnStat) = {
  // newMin/newMax are assumed to be defined here. They are still passed as Option[Any] because
  // they are put unchanged into the resulting column stat at the end of this method.
  val overlappedRanges = getOverlappedRanges(
    leftHistogram = leftHistogram,
    rightHistogram = rightHistogram,
    // Only numeric values have equi-height histograms.
    lowerBound = newMin.get.toString.toDouble,
    upperBound = newMax.get.toString.toDouble)

  var card: BigDecimal = 0
  var totalNdv: Double = 0
  for (i <- overlappedRanges.indices) {
    val range = overlappedRanges(i)
    if (i == 0 || range.hi != overlappedRanges(i - 1).hi) {
      // If range.hi == overlappedRanges(i - 1).hi, that means the current range has only one
      // value, and this value is already counted in the previous range. So there is no need to
      // count it in this range.
      totalNdv += math.min(range.leftNdv, range.rightNdv)
    }
    // Apply the formula in this overlapped range.
    card += range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv)
  }

  val leftKeyStat = leftStats.attributeStats(leftKey)
  val rightKeyStat = rightStats.attributeStats(rightKey)
  val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
  // Ideally avgLen would be (total length of matched keys) / (number of rows after join).
  // Only numeric keys have histograms, and for numeric types avgLen should be the same on both
  // sides, so the plain average is a fair approximation.
  val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
  // NOTE(review): the literal 0 is presumably the null count of the joined key - confirm
  // against the ColumnStat constructor.
  val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen)
  (ceil(card), newStats)
}
|
||
/** | ||
* Propagate or update column stats for output attributes. | ||
*/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
Then we can omit
val overlappedRanges = new ArrayBuffer[OverlappedRange]()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only collect `OverlappedRange` when the left part and the right part intersect, and that decision is based on some computation, so it is not very convenient to express it as a guard. So it seems the `yield` form is not very suitable for this case.