[SPARK-16282][SQL] Implement percentile SQL function. #14136
Conversation
Test build #62088 has finished for PR 14136 at commit
Nit: Style
Test build #62089 has finished for PR 14136 at commit
Please use a while loop here; for is not that efficient.
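For illustration, a hedged sketch of the kind of rewrite being asked for (hypothetical helper, not the PR's code); a while loop compiles to a plain counted loop, whereas Scala's `for` desugars into closure-based foreach calls that can be slower on hot paths:

```scala
// Sum an indexed sequence of counts with an explicit while loop.
def totalCount(counts: IndexedSeq[Long]): Long = {
  var i = 0
  var total = 0L
  while (i < counts.length) {
    total += counts(i)
    i += 1
  }
  total
}
```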
Yep, I'll update that.
@jiangxb1987 Thanks for working on this. I did a quick pass, and it is a good start. I have a few issues:
A more performant way to do this would be to plan it using a combination of count grouped by the percentile key and this percentile function. I am not sure if we should pursue that in this PR.
@jiangxb1987 I am just curious why we use OpenHashMap here instead of mutable.Map, which would correspond with the code here in Hive. Is there any specific reason?
OpenHashMap is typically faster and has less overhead.
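For context, a minimal sketch of the counting pattern with Spark's OpenHashMap (names and shapes are illustrative, not the PR's exact code); changeValue updates in place instead of the lookup-then-put round trip a mutable.Map needs:

```scala
import org.apache.spark.util.collection.OpenHashMap

// Count occurrences of each key: insert 1 on first sight, otherwise bump in place.
val counts = new OpenHashMap[java.lang.Long, Long]()
def add(key: java.lang.Long): Unit = {
  counts.changeValue(key, 1L, _ + 1L)
}
```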
@hvanhovell Thanks!
@hvanhovell Thank you for your kind review; the suggestions are quite useful to me. I'll try to find some time later today to push the fixes. Thanks!
Test build #62313 has finished for PR 14136 at commit
@hvanhovell I've fixed most of the problems mentioned above, and I also added basic tests and comments as you requested. Please find some time to do another pass, thanks!
Test build #62314 has finished for PR 14136 at commit
Test build #62315 has finished for PR 14136 at commit
Shouldn't we check here if a percentile is valid? Waiting until eval is really late in the game.
We should also check that the array is not empty.
We also need to remove the line here: https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala#L240.
Test build #62407 has finished for PR 14136 at commit
Test build #62408 has finished for PR 14136 at commit
Test build #62455 has finished for PR 14136 at commit
Test build #62466 has finished for PR 14136 at commit
Test build #69110 has started for PR 14136 at commit

retest this please.
Test build #69121 has finished for PR 14136 at commit
Test build #69144 has finished for PR 14136 at commit
```scala
  Countings()
}

private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = {
```
Why not return doubles?
```scala
copy(inputAggBufferOffset = newInputAggBufferOffset)

// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) =
```
This can be problematic with serialization. Just put the percentages in a @transient lazy val and inline the use of returnPercentileArray.
```scala
override def nullable: Boolean = true

override def dataType: DataType =
  if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
```
I think we should return the type of the input. We can always interpolate the value and cast it to the input type. Is this different from what Hive does?
Hive can return a double value or an array of double values even when the column data type is integer, for example:

```sql
hive> insert into tbl values(1),(2),(5),(10);
hive> select percentile(a, array(0, 0.25, 0.5, 0.75, 1)) from tbl;
[1.0,1.75,3.5,6.25,10.0]
```
```scala
// Returns null for empty inputs
override def nullable: Boolean = true

override def dataType: DataType =
```
```scala
override lazy val dataType: DataType = percentageExpression.dataType match {
  case _: ArrayType => ArrayType(DoubleType, false)
  case _ => DoubleType
}
```

```scala
Seq(NumericType, TypeCollection(NumericType, ArrayType))

override def checkInputDataTypes(): TypeCheckResult =
  TypeUtils.checkForNumericExpr(child.dataType, "function percentile")
```
Call super.checkInputDataTypes(); that will validate the inputTypes(). Also check the percentageExpression: it must be foldable, and the percentage(s) must be in the range [0, 1].
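A sketch of what such a check could look like (assuming a percentages: Seq[Double] value evaluated from the foldable percentageExpression; names follow the snippets quoted later in this review, not necessarily the final code):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}

override def checkInputDataTypes(): TypeCheckResult = {
  // Let the declared inputTypes() do the basic type validation first.
  val defaultCheck = super.checkInputDataTypes()
  if (defaultCheck.isFailure) {
    defaultCheck
  } else if (!percentageExpression.foldable) {
    // The percentage(s) must be constant so we can validate them at analysis time.
    TypeCheckFailure("The percentage(s) must be a constant literal, but got " +
      percentageExpression)
  } else if (percentages.exists(p => p < 0.0 || p > 1.0)) {
    TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, but got " +
      percentageExpression)
  } else {
    TypeCheckSuccess
  }
}
```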
BTW - you can make the analyzer add casts for you:
```scala
override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
  case _: ArrayType => Seq(NumericType, ArrayType(DoubleType, false))
  case _ => Seq(NumericType, DoubleType)
}
```

Then you are always sure you get a double or a double array for the percentageExpression.
hvanhovell left a comment
I did another pass. My main feedback is to consolidate this more into a single class.
```scala
/**
 * A class that stores the numbers and their counts, used to support [[Percentile]] function.
 */
class Countings(val counts: OpenHashMap[Number, Long]) extends Serializable {
```
Please remove this class and put its implementation in the Percentile Aggregate.
The class TypedImperativeAggregate[T] requires access to this class, so perhaps we should keep it outside of Percentile.
We could entirely remove the class Countings.
```scala
 */
class CountingsSerializer {

  final def serialize(obj: Countings, dataType: DataType): Array[Byte] = {
```
Just put this in the Percentile class.
```scala
  return Seq.empty
}

val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() {
```
Use child.dataType.asInstanceOf[NumericType].ordering.
Maybe a dumb question: how can we order a sequence of Number using the Ordering[NumericType#InternalType]?
You could cast the ordering?
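Concretely, the cast could look like this (the same pattern appears in a later revision, quoted further down):

```scala
// NumericType#ordering orders the type's internal representation; cast it so
// it can order the Number keys kept in the counts map.
val ordering =
  child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]
val sortedCounts = counts.toSeq.sortBy(_._1)(ordering)
```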
```scala
  override def compare(a: Number, b: Number): Int =
    scala.math.signum(a.doubleValue() - b.doubleValue()).toInt
})
val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
```
Just use an imperative loop.
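A sketch of an imperative replacement for the scanLeft (assumes sortedCounts: Seq[(Number, Long)] as in the quoted code; builds a running total per distinct value):

```scala
// Accumulate cumulative counts with a plain loop instead of scanLeft.
val accumulatedCounts = new Array[(Number, Long)](sortedCounts.size)
var i = 0
var total = 0L
while (i < sortedCounts.size) {
  val (value, count) = sortedCounts(i)
  total += count
  accumulatedCounts(i) = (value, total)
  i += 1
}
```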
```scala
val lower = position.floor
val higher = position.ceil

// Linear search since this won't take much time from the total execution anyway
```
That doesn't make it right :)... Anyway, there are enough binarySearch implementations around, so maybe use one of those.
This was taken from Hive's UDAFPercentile. It is fine if you do that, but please acknowledge that you have done so by adding a line of documentation. See this for example: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala#L524
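If a binary search is adopted instead, a lower-bound variant over the accumulated counts could look like this (hypothetical helper, not from the PR or Hive):

```scala
// Index of the first (value, cumulativeCount) pair whose cumulative count
// reaches `position`; accumulatedCounts must be sorted by cumulative count.
def indexAtPosition(accumulatedCounts: IndexedSeq[(Number, Long)], position: Long): Int = {
  var lo = 0
  var hi = accumulatedCounts.length - 1
  while (lo < hi) {
    val mid = (lo + hi) >>> 1
    if (accumulatedCounts(mid)._2 < position) lo = mid + 1 else hi = mid
  }
  lo
}
```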
```scala
counts.foreach { pair =>
  val row = InternalRow.apply(pair._1, pair._2)
  val unsafeRow = projection.apply(row)
  buffer ++= unsafeRow.getBytes
```
This is extremely expensive, because you are resizing the buffer for every entry. Please use a ByteArrayOutputStream and a DataOutputStream. See this for an example: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala#L226-L239
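A sketch following the linked SparkPlan pattern (counts and projection are assumed from the surrounding code; each row is length-prefixed so the reader can also handle variable-length values):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.sql.catalyst.InternalRow

val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
// Reusable scratch buffer for UnsafeRow.writeToStream, as in the linked example.
val writeBuffer = new Array[Byte](4 << 10)
counts.foreach { case (key, count) =>
  val unsafeRow = projection.apply(InternalRow.apply(key, count))
  // Length-prefix each serialized row, then stream its bytes without
  // reallocating a growing buffer per entry.
  out.writeInt(unsafeRow.getSizeInBytes)
  unsafeRow.writeToStream(out, writeBuffer)
}
out.flush()
val serialized = bos.toByteArray
```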
```scala
// Read the pairs of counts map
val row = new UnsafeRow(2)
val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2)
```
This might cause an issue for a DecimalType; a decimal does not have to be fixed-length. I think we need to write out the row sizes or disallow variable-length keys. BTW, if you only allow fixed-length keys, you could get rid of UnsafeRows and projections and directly use a DataOutputStream.
```scala
  Countings()
}

private def evalPercentages(expr: Expression): Seq[Double] = (expr.dataType, expr.eval()) match {
```
Move this to the definition of percentages. You can also make this much simpler. The analyzer guarantees that you either get a single double or an ArrayData of doubles:

```scala
@transient
private lazy val percentages = percentageExpression.eval() match {
  case p: Double => Seq(p)
  case a: ArrayData => a.toDoubleArray().toSeq
}
```

```scala
copy(inputAggBufferOffset = newInputAggBufferOffset)

// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
```
Mark it @transient.
```scala
  defaultCheck
} else if (!percentageExpression.foldable) {
  // percentageExpression must be foldable
  TypeCheckFailure(s"The percentage(s) must be a constant literal, " +
```
Nit: no string interpolation.
```scala
} else if (!percentageExpression.foldable) {
  // percentageExpression must be foldable
  TypeCheckFailure(s"The percentage(s) must be a constant literal, " +
    s"but got ${percentageExpression}")
```
Nit: you don't need {...}?
Test build #69186 has finished for PR 14136 at commit
Test build #69188 has finished for PR 14136 at commit
```scala
val sortedCounts = buffer.toSeq.sortBy(_._1)(
  child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
```
nit: maybe accumulatedCounts is a slightly better name than aggreCounts here?
Currently
Test build #69233 has finished for PR 14136 at commit
## What changes were proposed in this pull request?

Implement percentile SQL function. It computes the exact percentile(s) of expr at pc with range in [0, 1].

## How was this patch tested?

Add a new testsuite `PercentileSuite` to test percentile directly.

Updated related testcases in `ExpressionToSQLSuite`.

Author: jiangxingbo <jiangxb1987@gmail.com>
Author: 蒋星博 <jiangxingbo@meituan.com>
Author: jiangxingbo <jiangxingbo@meituan.com>

Closes #14136 from jiangxb1987/percentile.

(cherry picked from commit 0f5f52a)

Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
LGTM. Merging to master/2.1. Thanks!
@hvanhovell why did this go into branch-2.1? It's way past branch-cut time.
Hi @rxin, I was reading related code around this and saw #14136 (comment). It looks like many of the suggestions for calculating a median are workarounds (e.g., https://stackoverflow.com/a/31437177). I want to use:

```python
from pyspark.sql.functions import *
from pyspark.sql.column import Column, _to_java_column

def approximate_percentile(child, percentage, accuracy=lit(10000)):
    percentile_expr = spark.sparkContext._jvm.org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
    child_expr = _to_java_column(child).expr()
    percentage_expr = _to_java_column(percentage).expr()
    accuracy_expr = _to_java_column(accuracy).expr()
    agg_func = percentile_expr(child_expr, percentage_expr, accuracy_expr)
    return Column(spark._jvm.org.apache.spark.sql.Column(agg_func.toAggregateExpression()))

spark.range(1).groupby().agg(approximate_percentile(col("id"), lit(0.5))).show()
spark.range(1).groupby().pivot("id").agg(approximate_percentile(col("id"), lit(0.5))).show()
```

This code might easily be broken by a Spark version change, as it accesses internal packages via the JVM. I use this workaround for now.

Another alternative would be to port the existing logic on the application side to SQL, but I was wondering if I really should do this for a single case. It might be expensive, but exposing it might also encourage users to at least test this. Could we expose this in Scala/Python/R? It should be pretty easy to expose. Or did I misunderstand the context and the other workarounds?

cc @srowen and @zero323, who I saw answered questions related to this elsewhere (e.g., Stack Overflow).
What changes were proposed in this pull request?

Implement percentile SQL function. It computes the exact percentile(s) of expr at pc with range in [0, 1].

How was this patch tested?

Add a new testsuite `PercentileSuite` to test percentile directly. Updated related testcases in `ExpressionToSQLSuite`.
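For reference, a minimal usage sketch of the function this PR adds (assumes a running SparkSession named spark; the view name and values are illustrative):

```scala
// Exact percentiles of a numeric column via the new SQL function.
spark.range(1, 11).createOrReplaceTempView("t")
spark.sql("SELECT percentile(id, 0.5) FROM t").show()
spark.sql("SELECT percentile(id, array(0.25, 0.75)) FROM t").show()
```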