
Commit e7853dc

heary-cao authored and cloud-fan committed
[SPARK-24999][SQL] Reduce unnecessary 'new' memory operations
## What changes were proposed in this pull request?

This PR improves the CodeGen code generated for the fast hash map: there is no need to allocate a new block of memory for every new entry, because the UnsafeRow's memory can be reused.

## How was this patch tested?

Existing test cases.

Closes #21968 from heary-cao/updateNewMemory.

Authored-by: caoxuewen <cao.xuewen@zte.com.cn>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent f8b4d5a commit e7853dc
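
To make the change concrete, here is a minimal sketch of the pattern the generated code moves to: construct a single UnsafeRowWriter up front and call reset() before each reuse, instead of allocating a new writer for every new hash-map entry. This is not the generated code itself; the class name, field name, and parameters below are invented for illustration, standing in for `groupingKeySchema.length` and `numVarLenFields * 32` in the generator.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

// Hypothetical example class, not part of the patch.
class ReusedKeyWriter(numKeyFields: Int, initialVarLenSize: Int) {

  // Allocated once, like the new agg_rowWriter field in the generated class:
  // its backing buffer is reused across findOrInsert calls.
  private val rowWriter = new UnsafeRowWriter(numKeyFields, initialVarLenSize)

  def buildKeyRow(key: Long): UnsafeRow = {
    rowWriter.reset()            // rewind the write cursor; no new allocation
    rowWriter.zeroOutNullBytes() // only needed when some key field is nullable
    rowWriter.write(0, key)      // write the grouping key field(s)
    rowWriter.getRow()           // UnsafeRow view backed by the reused buffer
  }
}
```

When every grouping key field is non-nullable, the generated code skips the zeroOutNullBytes() call entirely, which is what the new `resetNullBits` value in the diff below controls.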

File tree

1 file changed: +18 -12 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala

Lines changed: 18 additions & 12 deletions
@@ -48,6 +48,12 @@ class RowBasedHashMapGenerator(
     val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
     val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
 
+    val numVarLenFields = groupingKeys.map(_.dataType).count {
+      case dt if UnsafeRow.isFixedLength(dt) => false
+      // TODO: consider large decimal and interval type
+      case _ => true
+    }
+
     s"""
        |  private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
        |  private int[] buckets;
@@ -60,6 +66,7 @@ class RowBasedHashMapGenerator(
        |  private long emptyVOff;
        |  private int emptyVLen;
        |  private boolean isBatchFull = false;
+       |  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
        |
        |
        |  public $generatedClassName(
@@ -75,6 +82,9 @@ class RowBasedHashMapGenerator(
        |    emptyVOff = Platform.BYTE_ARRAY_OFFSET;
        |    emptyVLen = emptyBuffer.length;
        |
+       |    agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
+       |      ${groupingKeySchema.length}, ${numVarLenFields * 32});
+       |
        |    buckets = new int[numBuckets];
        |    java.util.Arrays.fill(buckets, -1);
        |  }
@@ -112,12 +122,6 @@ class RowBasedHashMapGenerator(
    *
    */
   protected def generateFindOrInsert(): String = {
-    val numVarLenFields = groupingKeys.map(_.dataType).count {
-      case dt if UnsafeRow.isFixedLength(dt) => false
-      // TODO: consider large decimal and interval type
-      case _ => true
-    }
-
     val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
       key.dataType match {
         case t: DecimalType =>
@@ -130,6 +134,12 @@ class RowBasedHashMapGenerator(
       }
     }.mkString(";\n")
 
+    val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) {
+      ""
+    } else {
+      "agg_rowWriter.zeroOutNullBytes();"
+    }
+
     s"""
        |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
     groupingKeySignature}) {
@@ -140,12 +150,8 @@ class RowBasedHashMapGenerator(
        |      // Return bucket index if it's either an empty slot or already contains the key
        |      if (buckets[idx] == -1) {
        |        if (numRows < capacity && !isBatchFull) {
-       |          // creating the unsafe for new entry
-       |          org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
-       |            = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
-       |            ${groupingKeySchema.length}, ${numVarLenFields * 32});
-       |          agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
-       |          agg_rowWriter.zeroOutNullBytes();
+       |          agg_rowWriter.reset();
+       |          $resetNullBits
        |          ${createUnsafeRowForKey};
        |          org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
        |            = agg_rowWriter.getRow();
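
For reference, a small standalone sketch of how the reused writer's initial variable-length buffer size is derived, mirroring the numVarLenFields computation that this change hoists out of generateFindOrInsert. The grouping-key types below are assumptions chosen only for illustration.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

// Assumed grouping-key types, for illustration only.
val groupingKeyTypes: Seq[DataType] = Seq(LongType, StringType, IntegerType)

// Fixed-length fields live in the UnsafeRow's fixed region; only
// variable-length fields (here, the StringType column) are counted.
val numVarLenFields = groupingKeyTypes.count(dt => !UnsafeRow.isFixedLength(dt))

// The generator reserves 32 bytes per variable-length key field up front.
val initialBufferSize = numVarLenFields * 32  // 1 * 32 = 32 bytes here
```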
