|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions.aggregate |
19 | 19 |
|
20 | | -import java.nio.ByteBuffer |
21 | | - |
22 | | -import org.apache.spark.SparkConf |
23 | | -import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} |
24 | 20 | import org.apache.spark.sql.AnalysisException |
25 | 21 | import org.apache.spark.sql.catalyst.InternalRow |
26 | 22 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult |
27 | 23 | import org.apache.spark.sql.catalyst.expressions._ |
28 | 24 | import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings |
29 | 25 | import org.apache.spark.sql.catalyst.util._ |
30 | 26 | import org.apache.spark.sql.types._ |
| 27 | +import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET |
31 | 28 | import org.apache.spark.util.collection.OpenHashMap |
32 | 29 |
|
| 30 | + |
33 | 31 | /** |
34 | 32 | * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at |
35 | 33 | * the given percentage(s) with value range in [0.0, 1.0]. |
@@ -138,11 +136,11 @@ case class Percentile( |
138 | 136 | } |
139 | 137 |
|
140 | 138 | override def serialize(obj: Countings): Array[Byte] = { |
141 | | - Percentile.serializer.serialize(obj).array() |
| 139 | + Percentile.serializer.serialize(obj, child.dataType) |
142 | 140 | } |
143 | 141 |
|
144 | 142 | override def deserialize(bytes: Array[Byte]): Countings = { |
145 | | - Percentile.serializer.deserialize[Countings](ByteBuffer.wrap(bytes)) |
| 143 | + Percentile.serializer.deserialize(bytes, child.dataType) |
146 | 144 | } |
147 | 145 | } |
148 | 146 |
|
@@ -236,5 +234,59 @@ object Percentile { |
236 | 234 | } |
237 | 235 | } |
238 | 236 |
|
239 | | - val serializer: SerializerInstance = new KryoSerializer(new SparkConf).newInstance() |
| 237 | + |
| 238 | + /** |
| 239 | + * Serializer for class [[Countings]]. The serialized form is a single-field UnsafeRow |
| 240 | + * holding the size of the counts map, followed by one two-field UnsafeRow per |
| 241 | + * (value, count) pair. This class is thread-safe. |
| 242 | + */ |
| 243 | + class CountingsSerializer { |
| 244 | + |
| 245 | + final def serialize(obj: Countings, dataType: DataType): Array[Byte] = { |
| 246 | + val counts = obj.counts |
| 247 | + |
| 248 | + // Write the size of the counts map. |
| 249 | + val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType)) |
| 250 | + val row = InternalRow.apply(counts.size) |
| 251 | + var buffer = sizeProjection.apply(row).getBytes |
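| | + // Each append below rebuilds the array; assumed acceptable while the counts map stays small. |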
| 252 | + |
| 253 | + // Write the (value, count) pairs of the map. |
| 254 | + val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) |
| 255 | + counts.foreach { case (key, count) => |
| 256 | + val pairRow = InternalRow.apply(key, count) |
| 257 | + val unsafeRow = projection.apply(pairRow) |
| 258 | + buffer ++= unsafeRow.getBytes |
| 259 | + } |
| 260 | + |
| 261 | + buffer |
| 262 | + } |
| 263 | + |
| 264 | + final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = { |
| 265 | + val counts = new OpenHashMap[Number, Long] |
| 266 | + var offset = 0 |
| 267 | + |
| 268 | + // Read the size of the counts map. |
| 269 | + val sizeRow = new UnsafeRow(1) |
| 270 | + val rowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1) |
| 271 | + sizeRow.pointTo(bytes, rowSizeInBytes) |
| 272 | + val size = sizeRow.get(0, IntegerType).asInstanceOf[Int] |
| 273 | + offset += rowSizeInBytes |
| 274 | + |
| 275 | + // Read the (value, count) pairs of the map. |
| 276 | + val row = new UnsafeRow(2) |
| 277 | + val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2) |
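| | + // Assumes every pair row occupies only its fixed-length portion, which holds for |
| | + // fixed-width numeric values; a variable-length value would need the actual row size. |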
| 278 | + var i = 0 |
| 279 | + while (i < size) { |
| 280 | + row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes) |
| 281 | + val key = row.get(0, dataType).asInstanceOf[Number] |
| 282 | + val count = row.get(1, LongType).asInstanceOf[Long] |
| 283 | + offset += pairRowSizeInBytes |
| 284 | + counts.update(key, count) |
| 285 | + i += 1 |
| 286 | + } |
| 287 | + Countings(counts) |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + val serializer: CountingsSerializer = new CountingsSerializer |
240 | 292 | } |
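
For reference, a minimal round-trip of the new serializer could look like the sketch below. It assumes Countings simply wraps the OpenHashMap[Number, Long] built above; the value names are invented for illustration.

    val input = new OpenHashMap[Number, Long]
    input.update(Int.box(1), 2L)   // value 1 observed twice
    input.update(Int.box(5), 1L)   // value 5 observed once
    val bytes = Percentile.serializer.serialize(Countings(input), IntegerType)
    val restored = Percentile.serializer.deserialize(bytes, IntegerType)
    assert(restored.counts.size == input.size)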