Skip to content

Commit 816bf1a

Browse files
[SPARK-47836][SQL] Use doubles sketch replace the GK algorithm for approximate quantile computation, significantly improving merge performance
1 parent 0e10341 commit 816bf1a

File tree

5 files changed

+60
-99
lines changed

5 files changed

+60
-99
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 28 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import java.nio.ByteBuffer
21-
22-
import com.google.common.primitives.{Doubles, Ints, Longs}
20+
import org.apache.datasketches.memory.Memory
21+
import org.apache.datasketches.quantiles.{DoublesSketch, DoublesUnion, UpdateDoublesSketch}
2322

2423
import org.apache.spark.SparkException
2524
import org.apache.spark.sql.catalyst.InternalRow
@@ -31,10 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
3130
import org.apache.spark.sql.catalyst.trees.TernaryLike
3231
import org.apache.spark.sql.catalyst.types.PhysicalNumericType
3332
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
34-
import org.apache.spark.sql.catalyst.util.QuantileSummaries
35-
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
3633
import org.apache.spark.sql.types._
37-
import org.apache.spark.util.ArrayImplicits._
3834

3935
/**
4036
* The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
@@ -267,35 +263,40 @@ object ApproximatePercentile {
267263
// The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
268264
val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
269265

266+
def nextPowOf2(relativeError: Double): Int = {
267+
val baseK = DoublesSketch.getKFromEpsilon(relativeError, true)
268+
if (baseK == 1 || (baseK & (baseK - 1)) == 0) {
269+
baseK
270+
} else {
271+
Integer.highestOneBit(baseK) * 2
272+
}
273+
}
274+
270275
/**
271276
* PercentileDigest is a probabilistic data structure used for approximating percentiles
272-
* with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
277+
* with limited memory. PercentileDigest is backed by [[DoublesSketch]].
273278
*
274-
* @param summaries underlying probabilistic data structure [[QuantileSummaries]].
279+
* @param sketch underlying probabilistic data structure [[DoublesSketch]].
275280
*/
276-
class PercentileDigest(private var summaries: QuantileSummaries) {
281+
class PercentileDigest(private var sketch: UpdateDoublesSketch) {
277282

278283
def this(relativeError: Double) = {
279-
this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true))
284+
this(DoublesSketch.builder().setK(ApproximatePercentile.nextPowOf2(relativeError)).build())
280285
}
281286

282-
private[sql] def isCompressed: Boolean = summaries.compressed
283-
284-
/** Returns compressed object of [[QuantileSummaries]] */
285-
def quantileSummaries: QuantileSummaries = {
286-
if (!isCompressed) compress()
287-
summaries
288-
}
287+
def sketchInfo: UpdateDoublesSketch = sketch
289288

290289
/** Insert an observation value into the PercentileDigest data structure. */
291290
def add(value: Double): Unit = {
292-
summaries = summaries.insert(value)
291+
sketch.update(value)
293292
}
294293

295294
/** In-place merges in another PercentileDigest. */
296295
def merge(other: PercentileDigest): Unit = {
297-
if (!isCompressed) compress()
298-
summaries = summaries.merge(other.quantileSummaries)
296+
val doublesUnion = DoublesUnion.builder().setMaxK(sketch.getK).build()
297+
doublesUnion.union(sketch)
298+
doublesUnion.union(other.sketch)
299+
sketch = doublesUnion.getResult
299300
}
300301

301302
/**
@@ -309,17 +310,12 @@ object ApproximatePercentile {
309310
* }}}
310311
*/
311312
def getPercentiles(percentages: Array[Double]): Seq[Double] = {
312-
if (!isCompressed) compress()
313-
if (summaries.count == 0 || percentages.length == 0) {
314-
Array.emptyDoubleArray.toImmutableArraySeq
313+
if (!sketch.isEmpty) {
314+
sketch.getQuantiles(percentages).toSeq
315315
} else {
316-
summaries.query(percentages.toImmutableArraySeq).get
316+
Seq.empty[Double]
317317
}
318318
}
319-
320-
private final def compress(): Unit = {
321-
summaries = summaries.compress()
322-
}
323319
}
324320

325321
/**
@@ -329,52 +325,14 @@ object ApproximatePercentile {
329325
*/
330326
class PercentileDigestSerializer {
331327

332-
private final def length(summaries: QuantileSummaries): Int = {
333-
// summaries.compressThreshold, summary.relativeError, summary.count
334-
Ints.BYTES + Doubles.BYTES + Longs.BYTES +
335-
// length of summary.sampled
336-
Ints.BYTES +
337-
// summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
338-
summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES)
339-
}
340-
341328
final def serialize(obj: PercentileDigest): Array[Byte] = {
342-
val summary = obj.quantileSummaries
343-
val buffer = ByteBuffer.wrap(new Array(length(summary)))
344-
buffer.putInt(summary.compressThreshold)
345-
buffer.putDouble(summary.relativeError)
346-
buffer.putLong(summary.count)
347-
buffer.putInt(summary.sampled.length)
348-
349-
var i = 0
350-
while (i < summary.sampled.length) {
351-
val stat = summary.sampled(i)
352-
buffer.putDouble(stat.value)
353-
buffer.putLong(stat.g)
354-
buffer.putLong(stat.delta)
355-
i += 1
356-
}
357-
buffer.array()
329+
val sketch = obj.sketchInfo
330+
sketch.toByteArray(false)
358331
}
359332

360333
final def deserialize(bytes: Array[Byte]): PercentileDigest = {
361-
val buffer = ByteBuffer.wrap(bytes)
362-
val compressThreshold = buffer.getInt()
363-
val relativeError = buffer.getDouble()
364-
val count = buffer.getLong()
365-
val sampledLength = buffer.getInt()
366-
val sampled = new Array[Stats](sampledLength)
367-
368-
var i = 0
369-
while (i < sampledLength) {
370-
val value = buffer.getDouble()
371-
val g = buffer.getLong()
372-
val delta = buffer.getLong()
373-
sampled(i) = Stats(value, g, delta)
374-
i += 1
375-
}
376-
val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true)
377-
new PercentileDigest(summary)
334+
val sketch = DoublesSketch.heapify(Memory.wrap(bytes))
335+
new PercentileDigest(sketch.asInstanceOf[UpdateDoublesSketch])
378336
}
379337
}
380338

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,18 @@ class ApproximatePercentileSuite extends SparkFunSuite {
426426
}
427427

428428
private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = {
429-
val leftSummary = left.quantileSummaries
430-
val rightSummary = right.quantileSummaries
431-
leftSummary.compressThreshold == rightSummary.compressThreshold &&
432-
leftSummary.relativeError == rightSummary.relativeError &&
433-
leftSummary.count == rightSummary.count &&
434-
leftSummary.sampled.sameElements(rightSummary.sampled)
429+
val leftSketch = left.sketchInfo
430+
val rightSketch = right.sketchInfo
431+
if (leftSketch.isEmpty && rightSketch.isEmpty) {
432+
true
433+
} else if (leftSketch.isEmpty || rightSketch.isEmpty) {
434+
false
435+
} else {
436+
leftSketch.getK == rightSketch.getK &&
437+
leftSketch.getMaxItem == rightSketch.getMaxItem &&
438+
leftSketch.getMinItem == rightSketch.getMinItem &&
439+
leftSketch.getN == rightSketch.getN
440+
}
435441
}
436442

437443
private def assertEqual[T](left: T, right: T): Unit = {

sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql
2020
import java.sql.{Date, Timestamp}
2121
import java.time.{Duration, LocalDateTime, Period}
2222

23-
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
2423
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
2524
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
2625
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -291,18 +290,6 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
291290
}
292291
}
293292

294-
test("SPARK-24013: unneeded compress can cause performance issues with sorted input") {
295-
val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
296-
var compressCounts = 0
297-
(1 to 10000000).foreach { i =>
298-
buffer.add(i)
299-
if (buffer.isCompressed) compressCounts += 1
300-
}
301-
assert(compressCounts > 0)
302-
buffer.quantileSummaries
303-
assert(buffer.isCompressed)
304-
}
305-
306293
test("SPARK-32908: maximum target error in percentile_approx") {
307294
withTempView(table) {
308295
spark.read
@@ -318,7 +305,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
318305
| percentile_approx(col, 0.77, 100000),
319306
| percentile_approx(col, 0.77, 1000000)
320307
|FROM $table""".stripMargin),
321-
Row(18, 17, 17, 17))
308+
Row(17, 17, 17, 17))
322309
}
323310
}
324311

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,8 +1182,8 @@ class DataFrameAggregateSuite extends QueryTest
11821182
approx_percentile(col("earnings"), lit(0.3), lit(1)),
11831183
approx_percentile(col("earnings"), array(lit(0.3), lit(0.6)), lit(1))
11841184
),
1185-
Row("Java", 20000.0, Seq(20000.0, 30000.0), 20000.0, Seq(20000.0, 20000.0)) ::
1186-
Row("dotNET", 5000.0, Seq(5000.0, 10000.0), 5000.0, Seq(5000.0, 5000.0)) :: Nil
1185+
Row("Java", 20000.0, Seq(20000.0, 30000.0), 20000.0, Seq(20000.0, 30000.0)) ::
1186+
Row("dotNET", 5000.0, Seq(5000.0, 10000.0), 5000.0, Seq(5000.0, 10000.0)) :: Nil
11871187
)
11881188
}
11891189

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,19 +270,29 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
270270
val Array(s1_1, s2_1) = df.stat.approxQuantile("singles", Array(q1, q2), 1.0)
271271
val Array(Array(ms1_1, ms2_1), Array(md1_1, md2_1)) =
272272
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), 1.0)
273+
val errorSingle = 1000 * 1.0
274+
val errorDouble = 2.0 * errorSingle
275+
276+
assert(math.abs(single1_1 - q1 * n) <= errorSingle)
277+
assert(math.abs(s1_1 - q1 * n) <= errorSingle)
278+
assert(math.abs(s2_1 - q2 * n) <= errorSingle)
279+
assert(math.abs(ms1_1 - q1 * n) <= errorSingle)
280+
assert(math.abs(ms2_1 - q2 * n) <= errorSingle)
281+
assert(math.abs(md1_1 - 2 * q1 * n) <= errorDouble)
282+
assert(math.abs(md2_1 - 2 * q2 * n) <= errorDouble)
273283

274284
for (epsilon <- epsilons) {
275285
val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon)
276286
val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon)
277287
val Array(Array(ms1, ms2), Array(md1, md2)) =
278288
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
279-
assert(single1_1 === single1)
280-
assert(s1_1 === s1)
281-
assert(s2_1 === s2)
282-
assert(ms1_1 === ms1)
283-
assert(ms2_1 === ms2)
284-
assert(md1_1 === md1)
285-
assert(md2_1 === md2)
289+
assert(math.abs(single1 - q1 * n) <= errorSingle)
290+
assert(math.abs(s1 - q1 * n) <= errorSingle)
291+
assert(math.abs(s2 - q2 * n) <= errorSingle)
292+
assert(math.abs(ms1 - q1 * n) <= errorSingle)
293+
assert(math.abs(ms2 - q2 * n) <= errorSingle)
294+
assert(math.abs(md1 - 2 * q1 * n) <= errorDouble)
295+
assert(math.abs(md2 - 2 * q2 * n) <= errorDouble)
286296
}
287297
}
288298

0 commit comments

Comments
 (0)