Skip to content

Commit e01d0b2

Browse files
committed
Implement serializer for Percentile.
1 parent 4ace3bc commit e01d0b2

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,17 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import java.nio.ByteBuffer
21-
22-
import org.apache.spark.SparkConf
23-
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
2420
import org.apache.spark.sql.AnalysisException
2521
import org.apache.spark.sql.catalyst.InternalRow
2622
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2723
import org.apache.spark.sql.catalyst.expressions._
2824
import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings
2925
import org.apache.spark.sql.catalyst.util._
3026
import org.apache.spark.sql.types._
27+
import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET
3128
import org.apache.spark.util.collection.OpenHashMap
3229

30+
3331
/**
3432
* The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
3533
* the given percentage(s) with value range in [0.0, 1.0].
@@ -138,11 +136,11 @@ case class Percentile(
138136
}
139137

140138
/** Serializes the aggregation buffer by delegating to the shared
  * [[Percentile.serializer]], passing `child.dataType` so the row
  * projection matches the type of the aggregated column. */
override def serialize(obj: Countings): Array[Byte] =
  Percentile.serializer.serialize(obj, child.dataType)
143141

144142
/** Restores an aggregation buffer previously produced by [[serialize]];
  * `child.dataType` must match the type used when the bytes were written. */
override def deserialize(bytes: Array[Byte]): Countings =
  Percentile.serializer.deserialize(bytes, child.dataType)
147145
}
148146

@@ -236,5 +234,59 @@ object Percentile {
236234
}
237235
}
238236

239-
val serializer: SerializerInstance = new KryoSerializer(new SparkConf).newInstance()
237+
238+
/**
 * Serializer for class [[Countings]].
 *
 * Layout: one fixed-width UnsafeRow holding the entry count, followed by one
 * fixed-width UnsafeRow per (value, count) pair.
 *
 * NOTE(review): both `serialize` and `deserialize` rely on every row occupying
 * exactly its fixed portion (`UnsafeRow.calculateFixedPortionByteSize`). That
 * holds for the fixed-width numeric types Percentile accepts, but would break
 * for variable-width types such as DecimalType — confirm if the accepted types
 * ever widen.
 *
 * This class is thread safe: it keeps no mutable state; every call builds its
 * own projections and rows.
 */
class CountingsSerializer {

  /**
   * Encodes `obj` to a byte array.
   *
   * @param obj      the counts buffer to serialize
   * @param dataType the Catalyst type of the map keys (the aggregated column)
   */
  final def serialize(obj: Countings, dataType: DataType): Array[Byte] = {
    val counts = obj.counts

    // Header row: the number of entries in the counts map.
    val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType))
    val headerBytes = sizeProjection.apply(InternalRow.apply(counts.size)).getBytes

    // One row per (value, count) pair. Collect the chunks and concatenate
    // once at the end: the previous `buffer ++= bytes` on an Array reallocated
    // and copied the whole buffer on every iteration, i.e. O(n^2) overall.
    val projection = UnsafeProjection.create(Array[DataType](dataType, LongType))
    val chunks = scala.collection.mutable.ArrayBuffer(headerBytes)
    counts.foreach { pair =>
      chunks += projection.apply(InternalRow.apply(pair._1, pair._2)).getBytes
    }
    Array.concat(chunks.toSeq: _*)
  }

  /**
   * Decodes a byte array produced by [[serialize]].
   *
   * @param bytes    the serialized form
   * @param dataType the Catalyst type of the map keys; must match the type
   *                 passed to [[serialize]]
   */
  final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = {
    val counts = new OpenHashMap[Number, Long]

    // Read the header row to learn how many pairs follow. The two-argument
    // pointTo starts at BYTE_ARRAY_OFFSET, matching the explicit offsets below.
    val sizeRow = new UnsafeRow(1)
    val headerSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1)
    sizeRow.pointTo(bytes, headerSizeInBytes)
    // getInt returns a primitive; the previous asInstanceOf[Integer] boxed it
    // only to unbox again in the loop condition.
    val size = sizeRow.getInt(0)
    var offset = headerSizeInBytes

    // Walk the fixed-width pair rows.
    val row = new UnsafeRow(2)
    val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2)
    var i = 0
    while (i < size) {
      row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes)
      val key = row.get(0, dataType).asInstanceOf[Number]
      val count = row.getLong(1)
      counts.update(key, count)
      offset += pairRowSizeInBytes
      i += 1
    }
    Countings(counts)
  }
}

val serializer: CountingsSerializer = new CountingsSerializer
240292
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@ class PercentileSuite extends SparkFunSuite {
4040

4141
// Check empty serialize and de-serialize
4242
val emptyBuffer = Countings()
43-
assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer))))
43+
assert(compareEquals(emptyBuffer,
44+
serializer.deserialize(serializer.serialize(emptyBuffer, DoubleType), DoubleType)))
4445

4546
val buffer = Countings()
4647
data.foreach { value =>
4748
buffer.add(value)
4849
}
49-
assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer))))
50+
assert(compareEquals(buffer,
51+
serializer.deserialize(serializer.serialize(buffer, IntegerType), IntegerType)))
5052

51-
val agg = new Percentile(BoundReference(0, DoubleType, true), Literal(0.5))
53+
val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
5254
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
5355
}
5456

0 commit comments

Comments (0)