-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-13072][SQL] simplify and improve murmur3 hash expression codegen #10974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression | |
|
|
||
| override def genCode(ctx: CodegenContext, ev: ExprCode): String = { | ||
| ev.isNull = "false" | ||
| val childrenHash = children.zipWithIndex.map { | ||
| case (child, dt) => | ||
| val childGen = child.gen(ctx) | ||
| val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) | ||
| s""" | ||
| ${childGen.code} | ||
| if (!${childGen.isNull}) { | ||
| ${childHash.code} | ||
| ${ev.value} = ${childHash.value}; | ||
| } | ||
| """ | ||
| val childrenHash = children.map { child => | ||
| val childGen = child.gen(ctx) | ||
| childGen.code + generateNullCheck(child.nullable, childGen.isNull) { | ||
| computeHash(childGen.value, child.dataType, ev.value, ctx) | ||
| } | ||
| }.mkString("\n") | ||
|
|
||
| s""" | ||
| int ${ev.value} = $seed; | ||
| $childrenHash | ||
| """ | ||
| } | ||
|
|
||
| private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { | ||
| if (nullable) { | ||
| s""" | ||
| if (!$isNull) { | ||
| $execution | ||
| } | ||
| """ | ||
| } else { | ||
| "\n" + execution | ||
| } | ||
| } | ||
|
|
||
| private def nullSafeElementHash( | ||
| input: String, | ||
| index: String, | ||
| nullable: Boolean, | ||
| elementType: DataType, | ||
| result: String, | ||
| ctx: CodegenContext): String = { | ||
| val element = ctx.freshName("element") | ||
|
|
||
| generateNullCheck(nullable, s"$input.isNullAt($index)") { | ||
| s""" | ||
| final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; | ||
| ${computeHash(element, elementType, result, ctx)} | ||
| """ | ||
| } | ||
| } | ||
|
|
||
| private def computeHash( | ||
| input: String, | ||
| dataType: DataType, | ||
| seed: String, | ||
| ctx: CodegenContext): ExprCode = { | ||
| result: String, | ||
| ctx: CodegenContext): String = { | ||
| val hasher = classOf[Murmur3_x86_32].getName | ||
| def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") | ||
| def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") | ||
| def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) | ||
|
|
||
| def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" | ||
| def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" | ||
| def hashBytes(b: String): String = | ||
| s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" | ||
|
|
||
| dataType match { | ||
| case NullType => inlineValue(seed) | ||
| case NullType => "" | ||
| case BooleanType => hashInt(s"$input ? 1 : 0") | ||
| case ByteType | ShortType | IntegerType | DateType => hashInt(input) | ||
| case LongType | TimestampType => hashLong(input) | ||
|
|
@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression | |
| hashLong(s"$input.toUnscaledLong()") | ||
| } else { | ||
| val bytes = ctx.freshName("bytes") | ||
| val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" | ||
| val offset = "Platform.BYTE_ARRAY_OFFSET" | ||
| val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" | ||
| ExprCode(code, "false", result) | ||
| s""" | ||
| final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); | ||
| ${hashBytes(bytes)} | ||
| """ | ||
| } | ||
| case CalendarIntervalType => | ||
| val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" | ||
| val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" | ||
| inlineValue(monthsHash) | ||
| case BinaryType => | ||
| val offset = "Platform.BYTE_ARRAY_OFFSET" | ||
| inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") | ||
| val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" | ||
| s"$result = $hasher.hashInt($input.months, $microsecondsHash);" | ||
| case BinaryType => hashBytes(input) | ||
| case StringType => | ||
| val baseObject = s"$input.getBaseObject()" | ||
| val baseOffset = s"$input.getBaseOffset()" | ||
| val numBytes = s"$input.numBytes()" | ||
| inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") | ||
| s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" | ||
|
|
||
| case ArrayType(et, _) => | ||
| val result = ctx.freshName("result") | ||
| case ArrayType(et, containsNull) => | ||
| val index = ctx.freshName("index") | ||
| val element = ctx.freshName("element") | ||
| val elementHash = computeHash(element, et, result, ctx) | ||
| val code = | ||
| s""" | ||
| int $result = $seed; | ||
| for (int $index = 0; $index < $input.numElements(); $index++) { | ||
| if (!$input.isNullAt($index)) { | ||
| final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; | ||
| ${elementHash.code} | ||
| $result = ${elementHash.value}; | ||
| } | ||
| } | ||
| """ | ||
| ExprCode(code, "false", result) | ||
| s""" | ||
| for (int $index = 0; $index < $input.numElements(); $index++) { | ||
| ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} | ||
| } | ||
| """ | ||
|
|
||
| case MapType(kt, vt, _) => | ||
| val result = ctx.freshName("result") | ||
| case MapType(kt, vt, valueContainsNull) => | ||
| val index = ctx.freshName("index") | ||
| val keys = ctx.freshName("keys") | ||
| val values = ctx.freshName("values") | ||
| val key = ctx.freshName("key") | ||
| val value = ctx.freshName("value") | ||
| val keyHash = computeHash(key, kt, result, ctx) | ||
| val valueHash = computeHash(value, vt, result, ctx) | ||
| val code = | ||
| s""" | ||
| int $result = $seed; | ||
| final ArrayData $keys = $input.keyArray(); | ||
| final ArrayData $values = $input.valueArray(); | ||
| for (int $index = 0; $index < $input.numElements(); $index++) { | ||
| final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; | ||
| ${keyHash.code} | ||
| $result = ${keyHash.value}; | ||
| if (!$values.isNullAt($index)) { | ||
| final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; | ||
| ${valueHash.code} | ||
| $result = ${valueHash.value}; | ||
| } | ||
| } | ||
| """ | ||
| ExprCode(code, "false", result) | ||
| s""" | ||
| final ArrayData $keys = $input.keyArray(); | ||
| final ArrayData $values = $input.valueArray(); | ||
| for (int $index = 0; $index < $input.numElements(); $index++) { | ||
| ${nullSafeElementHash(keys, index, false, kt, result, ctx)} | ||
| ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} | ||
| } | ||
| """ | ||
|
|
||
| case StructType(fields) => | ||
| val result = ctx.freshName("result") | ||
| val fieldsHash = fields.map(_.dataType).zipWithIndex.map { | ||
| case (dt, index) => | ||
| val field = ctx.freshName("field") | ||
| val fieldHash = computeHash(field, dt, result, ctx) | ||
| s""" | ||
| if (!$input.isNullAt($index)) { | ||
| final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; | ||
| ${fieldHash.code} | ||
| $result = ${fieldHash.value}; | ||
| } | ||
| """ | ||
| fields.zipWithIndex.map { case (field, index) => | ||
| nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) | ||
| }.mkString("\n") | ||
| val code = | ||
| s""" | ||
| int $result = $seed; | ||
| $fieldsHash | ||
| """ | ||
| ExprCode(code, "false", result) | ||
|
|
||
| case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) | ||
| case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bug fix, right? Do we have tests for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not a bug; actually, I renamed the `seed` parameter to `result`.
||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there other places we can use this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, we can promote it to `CodegenContext` and use it in other places in a follow-up PR.