@@ -48,6 +48,12 @@ class RowBasedHashMapGenerator(
     val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
     val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
 
+    val numVarLenFields = groupingKeys.map(_.dataType).count {
+      case dt if UnsafeRow.isFixedLength(dt) => false
+      // TODO: consider large decimal and interval type
+      case _ => true
+    }
+
     s"""
       |  private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
       |  private int[] buckets;
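For reference, `UnsafeRow.isFixedLength` returns `false` for variable-length types such as strings and binary, so this count sizes the writer's initial variable-length region at 32 bytes per such key. An illustrative sketch (the key types below are made up, not from this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._

// Hypothetical grouping-key types, for illustration only.
val keyTypes = Seq(IntegerType, StringType, LongType, BinaryType)
val numVarLenFields = keyTypes.count(dt => !UnsafeRow.isFixedLength(dt))
// StringType and BinaryType are variable-length, so numVarLenFields == 2 and
// the writer is sized with 2 * 32 = 64 bytes of initial variable-length space.
```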
@@ -60,6 +66,7 @@ class RowBasedHashMapGenerator(
       |  private long emptyVOff;
       |  private int emptyVLen;
       |  private boolean isBatchFull = false;
+      |  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
       |
       |
       |  public $generatedClassName(
@@ -75,6 +82,9 @@ class RowBasedHashMapGenerator(
       |    emptyVOff = Platform.BYTE_ARRAY_OFFSET;
       |    emptyVLen = emptyBuffer.length;
       |
+      |    agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
+      |      ${groupingKeySchema.length}, ${numVarLenFields * 32});
+      |
       |    buckets = new int[numBuckets];
       |    java.util.Arrays.fill(buckets, -1);
       |  }
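The writer is now allocated once per generated hash map instead of once per inserted key. A minimal sketch of the allocate-once step, with illustrative sizes (not the patch's generated code):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

// Illustrative values: a 3-field grouping key with 1 variable-length field.
val groupingKeyFields = 3
val numVarLenFields = 1

// Mirrors the constructor code above: created once and reused afterwards.
// The second argument only sets the initial variable-length buffer size;
// the writer still grows its buffer on demand for oversized keys.
val aggRowWriter = new UnsafeRowWriter(groupingKeyFields, numVarLenFields * 32)
```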
@@ -112,12 +122,6 @@ class RowBasedHashMapGenerator(
    *
    */
   protected def generateFindOrInsert(): String = {
-    val numVarLenFields = groupingKeys.map(_.dataType).count {
-      case dt if UnsafeRow.isFixedLength(dt) => false
-      // TODO: consider large decimal and interval type
-      case _ => true
-    }
-
     val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
       key.dataType match {
         case t: DecimalType =>
@@ -130,6 +134,12 @@ class RowBasedHashMapGenerator(
       }
     }.mkString(";\n")
 
+    val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) {
+      ""
+    } else {
+      "agg_rowWriter.zeroOutNullBytes();"
+    }
+
     s"""
       |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
         groupingKeySignature}) {
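The null-bit reset is emitted only when at least one grouping key is nullable; for all-non-nullable keys the generated `findOrInsert` skips the `zeroOutNullBytes()` call entirely. A small sketch of the condition against a made-up schema:

```scala
import org.apache.spark.sql.types._

// Hypothetical all-non-nullable grouping-key schema, for illustration.
val groupingKeySchema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("bucket", IntegerType, nullable = false)))

val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) {
  ""  // no nullable keys: nothing is spliced into the template
} else {
  "agg_rowWriter.zeroOutNullBytes();"
}
// Here resetNullBits == "", so the generated code omits the call.
```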
@@ -140,12 +150,8 @@ class RowBasedHashMapGenerator(
       |    // Return bucket index if it's either an empty slot or already contains the key
       |    if (buckets[idx] == -1) {
       |      if (numRows < capacity && !isBatchFull) {
-      |        // creating the unsafe for new entry
-      |        org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
-      |          = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
-      |            ${groupingKeySchema.length}, ${numVarLenFields * 32});
-      |        agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
-      |        agg_rowWriter.zeroOutNullBytes();
+      |        agg_rowWriter.reset();
+      |        $resetNullBits
       |      ${createUnsafeRowForKey};
       |      org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
       |        = agg_rowWriter.getRow();
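With the per-key allocation gone, the shared `agg_rowWriter` is rewound with `reset()` before each insert. A minimal sketch of the reuse pattern (the wrapper function and field layout are illustrative, not generated code):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.unsafe.types.UTF8String

// Created once, like agg_rowWriter: 2 fields, 1 of them variable-length.
val aggRowWriter = new UnsafeRowWriter(2, 32)

def buildKey(id: Long, name: UTF8String): UnsafeRow = {
  aggRowWriter.reset()            // rewind the cursor; no reallocation
  aggRowWriter.zeroOutNullBytes() // only needed when a key field is nullable
  aggRowWriter.write(0, id)
  aggRowWriter.write(1, name)
  aggRowWriter.getRow()           // an UnsafeRow backed by the shared buffer
}
```

Successive calls reuse the same backing buffer, which is safe here because each key is copied into the `RowBasedKeyValueBatch` before the writer is reset again.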