@@ -20,7 +20,7 @@
import org.apache.spark.unsafe.Platform;

/**
- * Simulates Hive's hashing function at
+ * Simulates Hive's hashing function from Hive v1.2.1
* org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
*/
public class HiveHasher {
@@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
}
}


/**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
+ * Simulates Hive's hashing function from Hive v1.2.1 at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
*
* We should use this hash function for both shuffle and bucketing of Hive tables, so that
* we can guarantee shuffle and bucketing have the same data distribution.
@@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def hasherClassName: String = classOf[HiveHasher].getName

override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
- HiveHashFunction.hash(value, dataType, seed).toInt
+ HiveHashFunction.hash(value, dataType, this.seed).toInt
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
var i = 0
val length = struct.numFields
while (i < length) {
- result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+ result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
Member: Could you explain the reason?

Contributor Author: The seed is something used in the murmur3 hash; hive hash does not need it. See the original implementation in the Hive codebase: https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java#L638

Since the hashing-related methods in Spark already take a seed, I had to add it to hive-hash as well. When I compute the hash, I always need to set the seed to 0, which is what is done here.
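A minimal sketch of the fold in question (hypothetical helper, not part of the patch): starting from 0, each field's hash is combined as 31 * result + fieldHash, so the recurrence has no place for a murmur3-style seed.

// Hypothetical illustration of the Hive-style struct fold described above.
def hiveStyleStructHash(fieldHashes: Seq[Int]): Int =
  fieldHashes.foldLeft(0)((result, fieldHash) => 31 * result + fieldHash)

// For example, hiveStyleStructHash(Seq(1, 2, 3)) == 1026, which matches the
// expected value in the struct test added by this patch.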

i += 1
}
result

- case _ => super.hash(value, dataType, seed)
+ case _ => super.hash(value, dataType, 0)
}
}
}
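The doc comment in this file notes that the same hash function should drive both shuffle and bucketing of Hive tables so that the two layouts agree on data distribution. A minimal usage sketch under that assumption, e.g. in spark-shell (the bucket-id derivation below is illustrative, not taken from this patch):

import org.apache.spark.sql.catalyst.expressions.{HiveHash, Literal}

// Evaluate the Hive-compatible hash of a couple of literal values.
val hash = HiveHash(Seq(Literal(42), Literal("apache spark"))).eval().asInstanceOf[Int]

// A non-negative modulo would map the same hash to both a bucket id and a
// shuffle partition, which is why the two distributions can line up.
val numBuckets = 8
val bucketId = ((hash % numBuckets) + numBuckets) % numBuckets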
@@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions

import java.nio.charset.StandardCharsets

import scala.collection.mutable.ArrayBuffer

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String

class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val random = new scala.util.Random

test("md5") {
checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}


def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note: All expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)

withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
assert(actual == expected)
}
}

def checkHiveHashForIntegralType(dataType: DataType): Unit = {
// corner cases
checkHiveHash(null, dataType, 0)
checkHiveHash(1, dataType, 1)
checkHiveHash(0, dataType, 0)
checkHiveHash(-1, dataType, -1)
checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
checkHiveHash(Int.MinValue, dataType, Int.MinValue)

// random values
for (_ <- 0 until 10) {
val input = random.nextInt()
checkHiveHash(input, dataType, input)
}
}

test("hive-hash for null") {
checkHiveHash(null, NullType, 0)
}

test("hive-hash for boolean") {
checkHiveHash(true, BooleanType, 1)
checkHiveHash(false, BooleanType, 0)
}

test("hive-hash for byte") {
checkHiveHashForIntegralType(ByteType)
}

test("hive-hash for short") {
checkHiveHashForIntegralType(ShortType)
}

test("hive-hash for int") {
checkHiveHashForIntegralType(IntegerType)
}

test("hive-hash for long") {
checkHiveHash(1L, LongType, 1L)
checkHiveHash(0L, LongType, 0L)
checkHiveHash(-1L, LongType, 0L)
checkHiveHash(Long.MaxValue, LongType, -2147483648)
// Hive fails to parse this, but the hashing function itself can handle this input
checkHiveHash(Long.MinValue, LongType, -2147483648)

for (_ <- 0 until 10) {
val input = random.nextLong()
checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
}
}
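// Side note (sketch, not part of the original patch): for longs the expected value
// ((input >>> 32) ^ input).toInt coincides with java.lang.Long.hashCode(input), since
// Long.hashCode is defined as (int) (value ^ (value >>> 32)).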

test("hive-hash for float") {
checkHiveHash(0F, FloatType, 0)
checkHiveHash(0.0F, FloatType, 0)
checkHiveHash(1.1F, FloatType, 1066192077L)
checkHiveHash(-1.1F, FloatType, -1081291571)
checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
checkHiveHash(Float.MaxValue, FloatType, 2139095039)
checkHiveHash(Float.MinValue, FloatType, -8388609)
}

test("hive-hash for double") {
checkHiveHash(0, DoubleType, 0)
checkHiveHash(0.0, DoubleType, 0)
checkHiveHash(1.1, DoubleType, -1503133693)
checkHiveHash(-1.1, DoubleType, 644349955)
checkHiveHash(1000000000.000001, DoubleType, 1104006509)
checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
checkHiveHash(Double.MinValue, DoubleType, 1048576)
}

test("hive-hash for string") {
checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
// scalastyle:off nonascii
checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
// scalastyle:on nonascii
}
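// Side note (sketch, not part of the original patch): for pure-ASCII inputs this hash
// coincides with java.lang.String#hashCode, since both fold bytes/chars with a
// multiplier of 31; e.g. "apache spark".hashCode == 1142704523, as asserted above.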

test("hive-hash for array") {
// empty array
checkHiveHash(
input = new GenericArrayData(Array[Int]()),
dataType = ArrayType(IntegerType, containsNull = false),
expected = 0)

// basic case
checkHiveHash(
input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
dataType = ArrayType(IntegerType, containsNull = false),
expected = -2147172688L)

// with negative values
checkHiveHash(
input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
dataType = ArrayType(LongType, containsNull = false),
expected = -2147452680L)

// with nulls only
val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
checkHiveHash(
input = new GenericArrayData(Array(null, null)),
dataType = arrayTypeWithNull,
expected = 0)

// mix with null
checkHiveHash(
input = new GenericArrayData(Array(-12221, 89, null, 767)),
dataType = arrayTypeWithNull,
expected = -363989515)

// nested with array
checkHiveHash(
input = new GenericArrayData(
Array(
new GenericArrayData(Array(1234L, -9L, 67L)),
new GenericArrayData(Array(null, null)),
new GenericArrayData(Array(55L, -100L, -2147452680L))
)),
dataType = ArrayType(ArrayType(LongType)),
expected = -1007531064)

// nested with map
checkHiveHash(
input = new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
)),
dataType = ArrayType(MapType(IntegerType, StringType)),
expected = 1139205955)
}

test("hive-hash for map") {
val mapType = MapType(IntegerType, StringType)

// empty map
checkHiveHash(
input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
dataType = mapType,
expected = 0)

// basic case
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, 2)),
new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
dataType = mapType,
expected = 198872)

// with null value
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(55, -99)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
dataType = mapType,
expected = 1142704473)

// nesting (only values can be nested, as keys have to be a primitive datatype)
val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, -100)),
new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
))),
dataType = nestedMapType,
expected = -1142817416)
}

test("hive-hash for struct") {
// basic
val row = new GenericInternalRow(Array[Any](1, 2, 3))
checkHiveHash(
input = row,
dataType =
new StructType()
.add("col1", IntegerType)
.add("col2", IntegerType)
.add("col3", IntegerType),
expected = 1026)

// mix of several datatypes
val structType = new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
.add("int", IntegerType)
.add("long", LongType)
.add("arrayOfString", arrayOfString)
.add("mapOfString", mapOfString)

val rowValues = new ArrayBuffer[Any]()
rowValues += null
rowValues += true
rowValues += 1
rowValues += 2
rowValues += Int.MaxValue
rowValues += Long.MinValue
rowValues += new GenericArrayData(Array(
UTF8String.fromString("apache spark"),
UTF8String.fromString("hello world")
))
rowValues += new ArrayBasedMapData(
new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
)

val row2 = new GenericInternalRow(rowValues.toArray)
checkHiveHash(
input = row2,
dataType = structType,
expected = -2119012447)
}

private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)