Skip to content

Commit def94cc

Browse files
committed
Fix writing DecimalType bug in RowBasedHashMapGenerator.
1 parent 122cf18 commit def94cc

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,16 @@ class RowBasedHashMapGenerator(
141141
}
142142

143143
val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
144-
s"agg_rowWriter.write(${ordinal}, ${key.name})"}
145-
.mkString(";\n")
144+
key.dataType match {
145+
case t: DecimalType =>
146+
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
147+
case t: DataType =>
148+
if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
149+
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
150+
}
151+
s"agg_rowWriter.write(${ordinal}, ${key.name})"
152+
}
153+
}.mkString(";\n")
146154

147155
s"""
148156
|public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${

0 commit comments

Comments
 (0)