-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-10641][SQL] Add Skewness and Kurtosis Support #9003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
bc8ab0c
Added skewness and kurtosis aggregate functions
sethah cf52ed7
Adding kahan updates to higher order aggregate stats
sethah 7ecf50e
adding zero division protection
sethah 579b9f2
Adding order check to reduce calculation overhead
sethah 230f66c
style and scaladoc fixes
sethah 1c4c4d0
updating kurtosis test
sethah dc223bc
converting from codegen to imperative aggregate
sethah 83fb682
cleaning up and style fixes
sethah d54fb0d
cast child expression to double type
sethah 853922a
using vars for aggregator
sethah 4a5350e
restructuring eval method
sethah dba511b
style cleanup
sethah 44c1437
reverting some stddev changes
sethah 7baac9d
adding helper function for tests with tolerances
sethah 345463e
more generic tests with tolerance function and placeholders for Aggre…
sethah 3ef2faa
style and readability updates
sethah fd3f4d6
addressing feedback
sethah cf8a14b
correcting error in merge function
sethah b86386a
adding back some stddev codegen tests
sethah 3045e3b
changing variance to default to population variance
sethah ff363cc
removing fetch_aggregation from whitelist
sethah f49ce5c
Throw UnsupportedOperationException for AggregateExpression1
sethah File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -930,3 +930,332 @@ object HyperLogLogPlusPlus { | |
| ) | ||
| // scalastyle:on | ||
| } | ||
|
|
||
| /** | ||
| * A central moment is the expected value of a specified power of the deviation of a random | ||
| * variable from the mean. Central moments are often used to characterize the properties of about | ||
| * the shape of a distribution. | ||
| * | ||
| * This class implements online, one-pass algorithms for computing the central moments of a set of | ||
| * points. | ||
| * | ||
| * Behavior: | ||
| * - null values are ignored | ||
| * - returns `Double.NaN` when the column contains `Double.NaN` values | ||
| * | ||
| * References: | ||
| * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." | ||
| * 2015. http://arxiv.org/abs/1510.04923 | ||
| * | ||
| * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance | ||
| * Algorithms for calculating variance (Wikipedia)]] | ||
| * | ||
| * @param child to compute central moments of. | ||
| */ | ||
| abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { | ||
|
|
||
| /** | ||
| * The central moment order to be computed. | ||
| */ | ||
| protected def momentOrder: Int | ||
|
|
||
| override def children: Seq[Expression] = Seq(child) | ||
|
|
||
| override def nullable: Boolean = false | ||
|
|
||
| override def dataType: DataType = DoubleType | ||
|
|
||
| // Expected input data type. | ||
| // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the | ||
| // new version at planning time (after analysis phase). For now, NullType is added at here | ||
| // to make it resolved when we have cases like `select avg(null)`. | ||
| // We can use our analyzer to cast NullType to the default data type of the NumericType once | ||
| // we remove the old aggregate functions. Then, we will not need NullType at here. | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) | ||
|
|
||
| override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) | ||
|
|
||
| /** | ||
| * Size of aggregation buffer. | ||
| */ | ||
| private[this] val bufferSize = 5 | ||
|
|
||
| override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => | ||
| AttributeReference(s"M$i", DoubleType)() | ||
| } | ||
|
|
||
| // Note: although this simply copies aggBufferAttributes, this common code can not be placed | ||
| // in the superclass because that will lead to initialization ordering issues. | ||
| override val inputAggBufferAttributes: Seq[AttributeReference] = | ||
| aggBufferAttributes.map(_.newInstance()) | ||
|
|
||
| // buffer offsets | ||
| private[this] val nOffset = mutableAggBufferOffset | ||
| private[this] val meanOffset = mutableAggBufferOffset + 1 | ||
| private[this] val secondMomentOffset = mutableAggBufferOffset + 2 | ||
| private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 | ||
| private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 | ||
|
|
||
| // frequently used values for online updates | ||
| private[this] var delta = 0.0 | ||
| private[this] var deltaN = 0.0 | ||
| private[this] var delta2 = 0.0 | ||
| private[this] var deltaN2 = 0.0 | ||
| private[this] var n = 0.0 | ||
| private[this] var mean = 0.0 | ||
| private[this] var m2 = 0.0 | ||
| private[this] var m3 = 0.0 | ||
| private[this] var m4 = 0.0 | ||
|
|
||
| /** | ||
| * Initialize all moments to zero. | ||
| */ | ||
| override def initialize(buffer: MutableRow): Unit = { | ||
| for (aggIndex <- 0 until bufferSize) { | ||
| buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Update the central moments buffer. | ||
| */ | ||
| override def update(buffer: MutableRow, input: InternalRow): Unit = { | ||
| val v = Cast(child, DoubleType).eval(input) | ||
| if (v != null) { | ||
| val updateValue = v match { | ||
| case d: Double => d | ||
| } | ||
|
|
||
| n = buffer.getDouble(nOffset) | ||
| mean = buffer.getDouble(meanOffset) | ||
|
|
||
| n += 1.0 | ||
| buffer.setDouble(nOffset, n) | ||
| delta = updateValue - mean | ||
| deltaN = delta / n | ||
| mean += deltaN | ||
| buffer.setDouble(meanOffset, mean) | ||
|
|
||
| if (momentOrder >= 2) { | ||
| m2 = buffer.getDouble(secondMomentOffset) | ||
| m2 += delta * (delta - deltaN) | ||
| buffer.setDouble(secondMomentOffset, m2) | ||
| } | ||
|
|
||
| if (momentOrder >= 3) { | ||
| delta2 = delta * delta | ||
| deltaN2 = deltaN * deltaN | ||
| m3 = buffer.getDouble(thirdMomentOffset) | ||
| m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) | ||
| buffer.setDouble(thirdMomentOffset, m3) | ||
| } | ||
|
|
||
| if (momentOrder >= 4) { | ||
| m4 = buffer.getDouble(fourthMomentOffset) | ||
| m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + | ||
| delta * (delta * delta2 - deltaN * deltaN2) | ||
| buffer.setDouble(fourthMomentOffset, m4) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Merge two central moment buffers. | ||
| */ | ||
| override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { | ||
| val n1 = buffer1.getDouble(nOffset) | ||
| val n2 = buffer2.getDouble(inputAggBufferOffset) | ||
| val mean1 = buffer1.getDouble(meanOffset) | ||
| val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) | ||
|
|
||
| var secondMoment1 = 0.0 | ||
| var secondMoment2 = 0.0 | ||
|
|
||
| var thirdMoment1 = 0.0 | ||
| var thirdMoment2 = 0.0 | ||
|
|
||
| var fourthMoment1 = 0.0 | ||
| var fourthMoment2 = 0.0 | ||
|
|
||
| n = n1 + n2 | ||
| buffer1.setDouble(nOffset, n) | ||
| delta = mean2 - mean1 | ||
| deltaN = if (n == 0.0) 0.0 else delta / n | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed divide by zero case here, which was causing problems when number of partitions > number of samples. |
||
| mean = mean1 + deltaN * n2 | ||
| buffer1.setDouble(mutableAggBufferOffset + 1, mean) | ||
|
|
||
| // higher order moments computed according to: | ||
| // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics | ||
| if (momentOrder >= 2) { | ||
| secondMoment1 = buffer1.getDouble(secondMomentOffset) | ||
| secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) | ||
| m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 | ||
| buffer1.setDouble(secondMomentOffset, m2) | ||
| } | ||
|
|
||
| if (momentOrder >= 3) { | ||
| thirdMoment1 = buffer1.getDouble(thirdMomentOffset) | ||
| thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) | ||
| m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * | ||
| (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) | ||
| buffer1.setDouble(thirdMomentOffset, m3) | ||
| } | ||
|
|
||
| if (momentOrder >= 4) { | ||
| fourthMoment1 = buffer1.getDouble(fourthMomentOffset) | ||
| fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) | ||
| m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * | ||
| n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * | ||
| (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + | ||
| 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) | ||
| buffer1.setDouble(fourthMomentOffset, m4) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Compute aggregate statistic from sufficient moments. | ||
| * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) | ||
| * needed to compute the aggregate stat. | ||
| */ | ||
| def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double | ||
|
|
||
| override final def eval(buffer: InternalRow): Any = { | ||
| val n = buffer.getDouble(nOffset) | ||
| val mean = buffer.getDouble(meanOffset) | ||
| val moments = Array.ofDim[Double](momentOrder + 1) | ||
| moments(0) = 1.0 | ||
| moments(1) = 0.0 | ||
| if (momentOrder >= 2) { | ||
| moments(2) = buffer.getDouble(secondMomentOffset) | ||
| } | ||
| if (momentOrder >= 3) { | ||
| moments(3) = buffer.getDouble(thirdMomentOffset) | ||
| } | ||
| if (momentOrder >= 4) { | ||
| moments(4) = buffer.getDouble(fourthMomentOffset) | ||
| } | ||
|
|
||
| getStatistic(n, mean, moments) | ||
| } | ||
| } | ||
|
|
||
| case class Variance(child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
|
||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
|
||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
|
||
| override def prettyName: String = "variance" | ||
|
|
||
| override protected val momentOrder = 2 | ||
|
|
||
| override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
| require(moments.length == momentOrder + 1, | ||
| s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
|
|
||
| if (n == 0.0) Double.NaN else moments(2) / n | ||
| } | ||
| } | ||
|
|
||
| case class VarianceSamp(child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
|
||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
|
||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
|
||
| override def prettyName: String = "variance_samp" | ||
|
|
||
| override protected val momentOrder = 2 | ||
|
|
||
| override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
| require(moments.length == momentOrder + 1, | ||
| s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") | ||
|
|
||
| if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) | ||
| } | ||
| } | ||
|
|
||
| case class VariancePop(child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
|
||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
|
||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
|
||
| override def prettyName: String = "variance_pop" | ||
|
|
||
| override protected val momentOrder = 2 | ||
|
|
||
| override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
| require(moments.length == momentOrder + 1, | ||
| s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") | ||
|
|
||
| if (n == 0.0) Double.NaN else moments(2) / n | ||
| } | ||
| } | ||
|
|
||
| case class Skewness(child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
|
||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
|
||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
|
||
| override def prettyName: String = "skewness" | ||
|
|
||
| override protected val momentOrder = 3 | ||
|
|
||
| override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
| require(moments.length == momentOrder + 1, | ||
| s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
| val m2 = moments(2) | ||
| val m3 = moments(3) | ||
| if (n == 0.0 || m2 == 0.0) { | ||
| Double.NaN | ||
| } else { | ||
| math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| case class Kurtosis(child: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
|
||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
|
||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
| copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
|
||
| override def prettyName: String = "kurtosis" | ||
|
|
||
| override protected val momentOrder = 4 | ||
|
|
||
| // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy | ||
| override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
| require(moments.length == momentOrder + 1, | ||
| s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
| val m2 = moments(2) | ||
| val m4 = moments(4) | ||
| if (n == 0.0 || m2 == 0.0) { | ||
| Double.NaN | ||
| } else { | ||
| n * m4 / (m2 * m2) - 3.0 | ||
| } | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also document the behavior for
nullandNaNvalues.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not for the case when
n = 0orn =1but when we havenullorNaNin the values. Are we ignoring them or outputtingNaNdirectly?