@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with original and another implementations.
// But remain it for backward compatibility for the components existing before 2.3.
Review comment (Contributor):
NIT: touch this up. The sentence is a bit weird, i.e.: However we retain this implementation for backwards compatibility with pre-existing (pre 2.3) components.

Reply (Member, PR author):
Will correct it in the follow-up PR.

assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with original and another implementations.
// Use this method for new components after Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
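For orientation, here is a minimal sketch of how the new hashUnsafeBytes2 above is meant to be called on an on-heap byte array, cross-checked against Scala's MurmurHash3 in the same way the new unit test later in this diff does. The standalone class name and the choice of input and seed are illustrative only, not part of this PR.

```java
import java.nio.charset.StandardCharsets;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

public class HashUnsafeBytes2Example {
  public static void main(String[] args) {
    byte[] bytes = "spark".getBytes(StandardCharsets.UTF_8);
    int seed = 42;

    // New variant added by this PR: the trailing (non word-aligned) bytes are
    // folded into a single k1 block before mixing, which is what aligns it
    // with other MurmurHash3 x86_32 implementations.
    int sparkHash = Murmur3_x86_32.hashUnsafeBytes2(
        bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed);

    // Scala's implementation of the same algorithm, used as the reference.
    int scalaHash = MurmurHash3$.MODULE$.bytesHash(bytes, seed);

    // With hashUnsafeBytes2 the two values agree; the pre-existing
    // hashUnsafeBytes is, per its comment, not expected to match.
    System.out.println(sparkHash == scalaHash);
  }
}
```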
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with original and another implementations.
// But remain it for backward compatibility for the components existing before 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with original and another implementations.
// Use this method for new components after Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
@@ -22,6 +22,8 @@
import java.util.Random;
import java.util.Set;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.junit.Assert;
import org.junit.Test;
@@ -51,6 +53,23 @@ public void testKnownLongInputs() {
Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
}

// SPARK-23381 Check whether the hash of the byte array matches other implementations
Review comment (Contributor):
We could add a randomized (disabled) test here. That might increase the confidence we have in the new hash.

Reply (Member, PR author):
Yeah, we can do it in the follow-up PR.

@Test
public void testKnownBytesInputs() {
byte[] test = "test".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0),
Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0));
byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0),
Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0));
byte[] te = "te".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0),
Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0));
byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0),
Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0));
}

@Test
public void randomizedStressTest() {
int size = 65536;
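The reviewer above suggests a randomized, disabled-by-default test to build confidence in the new hash. A minimal sketch of what that could look like follows; the class name, trial count, length bound, and use of JUnit's @Ignore are assumptions, not part of this PR.

```java
import java.util.Random;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

public class Murmur3CrossCheckSketch {

  // Disabled by default, as the reviewer proposes; remove @Ignore to run locally.
  @Ignore
  @Test
  public void randomizedCompatibilityWithScalaMurmurHash3() {
    Random rand = new Random(42);
    for (int trial = 0; trial < 10000; trial++) {
      // Random length (including zero), random contents, random seed.
      int length = rand.nextInt(256);
      byte[] bytes = new byte[length];
      rand.nextBytes(bytes);
      int seed = rand.nextInt();

      int expected = MurmurHash3$.MODULE$.bytesHash(bytes, seed);
      int actual = Murmur3_x86_32.hashUnsafeBytes2(
          bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed);

      Assert.assertEquals(expected, actual);
    }
  }
}
```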
@@ -17,6 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
@@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

@@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer

@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val hashFunc: Any => Int = OldHashingTF.murmur3Hash
val hashFunc: Any => Int = FeatureHasher.murmur3Hash
val n = $(numFeatures)
val localInputCols = $(inputCols)
val catCols = if (isSet(categoricalCols)) {
@@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {

@Since("2.3.0")
override def load(path: String): FeatureHasher = super.load(path)

private val seed = OldHashingTF.seed

/**
* Calculate a hash code value for the term object using
* Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
* This is the default hash algorithm used from Spark 2.0 onwards.
* Use hashUnsafeBytes2 to match the original algorithm with the value.
* See SPARK-23381.
*/
@Since("2.3.0")
private[feature] def murmur3Hash(term: Any): Int = {
term match {
case null => seed
case b: Boolean => hashInt(if (b) 1 else 0, seed)
case b: Byte => hashInt(b, seed)
case s: Short => hashInt(s, seed)
case i: Int => hashInt(i, seed)
case l: Long => hashLong(l, seed)
case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
case s: String =>
val utf8 = UTF8String.fromString(s)
hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
s"support type ${term.getClass.getCanonicalName} of input data.")
}
}
}
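To make the change concrete, here is a small sketch of what FeatureHasher.murmur3Hash now computes for a string term (it is private[feature], so the sketch re-derives the same steps rather than calling it) and how a helper like the suite's murmur3FeatureIdx turns the hash into a feature index. The standalone class, the 262144-feature size (matching the vectors in the Python doctest later in this diff), and the inlined non-negative modulo are illustrative; the seed value 42 comes from HashingTF, as shown in the HashingTF.scala change below.

```java
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;

public class FeatureIndexSketch {
  public static void main(String[] args) {
    int numFeatures = 1 << 18;  // 262144 features, as in the Python example
    int seed = 42;              // HashingTF.seed, exposed to FeatureHasher by this PR

    // Hash a string term the way FeatureHasher.murmur3Hash now does:
    // through UTF8String and the Scala-compatible hashUnsafeBytes2.
    UTF8String utf8 = UTF8String.fromString("real");
    int hash = Murmur3_x86_32.hashUnsafeBytes2(
        utf8.getBaseObject(), utf8.getBaseOffset(), utf8.numBytes(), seed);

    // Non-negative modulo, mirroring what Utils.nonNegativeMod does in the
    // murmur3FeatureIdx test helper.
    int featureIdx = ((hash % numFeatures) + numFeatures) % numFeatures;
    System.out.println(featureIdx);
  }
}
```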
@@ -135,7 +135,7 @@ object HashingTF {

private[HashingTF] val Murmur3: String = "murmur3"

private val seed = 42
private[spark] val seed = 42

/**
* Calculate a hash code value for the term object using the native Scala implementation.
@@ -27,14 +27,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class FeatureHasherSuite extends SparkFunSuite
with MLlibTestSparkContext
with DefaultReadWriteTest {

import testImplicits._

import HashingTFSuite.murmur3FeatureIdx
import FeatureHasherSuite.murmur3FeatureIdx

implicit private val vectorEncoder = ExpressionEncoder[Vector]()

@@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite
testDefaultReadWrite(t)
}
}

object FeatureHasherSuite {

private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
}

}
python/pyspark/ml/feature.py (4 changes: 2 additions & 2 deletions)
@@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
>>> df = spark.createDataFrame(data, cols)
>>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
>>> hasher.transform(df).head().features
SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0})
SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0})
SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)