Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
val childrenHash = children.zipWithIndex.map {
case (child, dt) =>
val childGen = child.gen(ctx)
val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
s"""
${childGen.code}
if (!${childGen.isNull}) {
${childHash.code}
${ev.value} = ${childHash.value};
}
"""
val childrenHash = children.map { child =>
val childGen = child.gen(ctx)
childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
}.mkString("\n")

s"""
int ${ev.value} = $seed;
$childrenHash
"""
}

private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other places we can use this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. We can promote it to CodegenContext and use it in other places in a follow-up PR.

if (nullable) {
s"""
if (!$isNull) {
$execution
}
"""
} else {
"\n" + execution
}
}

/**
 * Builds the Java snippet that hashes one element of `input` into `result`.
 *
 * The element at `index` is loaded into a freshly named local and handed to `computeHash`.
 * The snippet is passed through `generateNullCheck` with `input.isNullAt(index)` as the
 * null test, so for a nullable element the load and the hash are guarded by that test.
 *
 * @param input       name of the Java variable the element is read from
 * @param index       Java expression for the position of the element inside `input`
 * @param nullable    whether a null guard has to be emitted around the snippet
 * @param elementType data type of the element
 * @param result      name of the Java variable the hash is accumulated in
 * @param ctx         codegen context supplying fresh names, Java types and accessors
 * @return the generated Java source
 */
private def nullSafeElementHash(
    input: String,
    index: String,
    nullable: Boolean,
    elementType: DataType,
    result: String,
    ctx: CodegenContext): String = {
  // Evaluation order mirrors the original: fresh name first, then the pieces of the snippet.
  val elementVar = ctx.freshName("element")
  val elementJavaType = ctx.javaType(elementType)
  val elementAccess = ctx.getValue(input, elementType, index)
  val elementHash = computeHash(elementVar, elementType, result, ctx)

  // The literal below is emitted verbatim, so its line layout is kept exactly as it was.
  val hashSnippet =
    s"""
final $elementJavaType $elementVar = $elementAccess;
$elementHash
"""

  val elementIsNull = s"$input.isNullAt($index)"
  generateNullCheck(nullable, elementIsNull)(hashSnippet)
}

private def computeHash(
input: String,
dataType: DataType,
seed: String,
ctx: CodegenContext): ExprCode = {
result: String,
ctx: CodegenContext): String = {
val hasher = classOf[Murmur3_x86_32].getName
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)

def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
def hashBytes(b: String): String =
s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"

dataType match {
case NullType => inlineValue(seed)
case NullType => ""
case BooleanType => hashInt(s"$input ? 1 : 0")
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
case LongType | TimestampType => hashLong(input)
Expand All @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
hashLong(s"$input.toUnscaledLong()")
} else {
val bytes = ctx.freshName("bytes")
val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
val offset = "Platform.BYTE_ARRAY_OFFSET"
val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
ExprCode(code, "false", result)
s"""
final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
${hashBytes(bytes)}
"""
}
case CalendarIntervalType =>
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
inlineValue(monthsHash)
case BinaryType =>
val offset = "Platform.BYTE_ARRAY_OFFSET"
inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
case BinaryType => hashBytes(input)
case StringType =>
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"

case ArrayType(et, _) =>
val result = ctx.freshName("result")
case ArrayType(et, containsNull) =>
val index = ctx.freshName("index")
val element = ctx.freshName("element")
val elementHash = computeHash(element, et, result, ctx)
val code =
s"""
int $result = $seed;
for (int $index = 0; $index < $input.numElements(); $index++) {
if (!$input.isNullAt($index)) {
final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
${elementHash.code}
$result = ${elementHash.value};
}
}
"""
ExprCode(code, "false", result)
s"""
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
}
"""

case MapType(kt, vt, _) =>
val result = ctx.freshName("result")
case MapType(kt, vt, valueContainsNull) =>
val index = ctx.freshName("index")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val key = ctx.freshName("key")
val value = ctx.freshName("value")
val keyHash = computeHash(key, kt, result, ctx)
val valueHash = computeHash(value, vt, result, ctx)
val code =
s"""
int $result = $seed;
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
for (int $index = 0; $index < $input.numElements(); $index++) {
final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
${keyHash.code}
$result = ${keyHash.value};
if (!$values.isNullAt($index)) {
final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
${valueHash.code}
$result = ${valueHash.value};
}
}
"""
ExprCode(code, "false", result)
s"""
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(keys, index, false, kt, result, ctx)}
${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
}
"""

case StructType(fields) =>
val result = ctx.freshName("result")
val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
case (dt, index) =>
val field = ctx.freshName("field")
val fieldHash = computeHash(field, dt, result, ctx)
s"""
if (!$input.isNullAt($index)) {
final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
${fieldHash.code}
$result = ${fieldHash.value};
}
"""
fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}.mkString("\n")
val code =
s"""
int $result = $seed;
$fieldsHash
"""
ExprCode(code, "false", result)

case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug fix, right? Do we have tests for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a bug. I actually renamed seed to result in this method to make it clear that we use the previous result as the seed to produce the new result.

}
}
}