From 559c09904538012b70bcb3493b8bc287dd855b2d Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Mon, 7 Nov 2016 13:30:32 -0800
Subject: [PATCH 1/5] [SPARK-18334] MinHash should use binary hash distance

---
 .../main/scala/org/apache/spark/ml/feature/MinHash.scala | 6 +++++-
 .../scala/org/apache/spark/ml/feature/MinHashSuite.scala | 6 ++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
index d9d0f32254e2..e7ba6375168e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
@@ -76,7 +76,11 @@ class MinHashModel private[ml] (
   @Since("2.1.0")
   override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
     // Since it's generated by hashing, it will be a pair of dense vectors.
-    x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
+    if (x.toDense.values.zip(y.toDense.values).exists(pair => pair._1 == pair._2)) {
+      0
+    } else {
+      1
+    }
   }

   @Since("2.1.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
index c32ca7d69cf8..cf6290b315ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
@@ -69,9 +69,11 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
     val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
     val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
     val keyDist = model.keyDistance(v1, v2)
-    val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
+    val hashDist1 = model.hashDistance(Vectors.dense(1, 2), Vectors.dense(3, 4))
+    val hashDist2 = model.hashDistance(Vectors.dense(1, 2), Vectors.dense(3, 2))
     assert(keyDist === 0.5)
-    assert(hashDist === 3)
+    assert(hashDist1 === 1.0)
+    assert(hashDist2 === 0.0)
   }

   test("MinHash: test of LSH property") {

From 517a97bd16f3771d9abbcdf54957a011f5f87adc Mon Sep 17 00:00:00 2001
From: Yunni
Date: Tue, 8 Nov 2016 01:15:24 -0500
Subject: [PATCH 2/5] Remove misleading documentation as requested

---
 .../main/scala/org/apache/spark/ml/feature/MinHash.scala | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
index e7ba6375168e..488c4ede5f45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
@@ -32,13 +32,7 @@ import org.apache.spark.sql.types.StructType
  * :: Experimental ::
  *
  * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
- * a perfect hash function:
- *    `h_i(x) = (x * k_i mod prime) mod numEntries`
- * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
- *
- * Reference:
- * [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]]
- *
+ * a perfect hash.
  * @param numEntries The number of entries of the hash functions.
  * @param randCoefficients An array of random coefficients, each used by one hash function.
  */

From b546dbd207a04e73bde097f25cae8c927322c2ae Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Tue, 8 Nov 2016 10:54:09 -0800
Subject: [PATCH 3/5] Add warning for multi-probe in MinHash

---
 .../org/apache/spark/ml/feature/LSH.scala     | 18 +++++++++++++-----
 .../org/apache/spark/ml/feature/MinHash.scala |  8 ++++++++
 .../org/apache/spark/ml/feature/LSHTest.scala |  7 ++++---
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 333a8c364a88..a0b2e75bdd02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -99,6 +99,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     validateAndTransformSchema(schema)
   }

+  /**
+   * Check prerequisites for nearest neighbor search. This method will be overridden in subclasses.
+   *
+   * @param singleProbe True for using single-probe; false for multi-probe
+   */
+  protected[this] def checkNearestNeighbor(singleProbe: Boolean) = {}
+
   /**
    * Given a large dataset and an item, approximately find at most k items which have the closest
    * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
@@ -106,13 +113,13 @@
    * transformed data when necessary.
    *
    * This method implements two ways of fetching k nearest neighbors:
-   * - Single Probing: Fast, return at most k elements (Probing only one buckets)
-   * - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
+   * - Single-probe: Fast, returns at most k elements (probing only one bucket)
+   * - Multi-probe: Slow, returns exactly k elements (probing multiple buckets close to the key)
    *
    * @param dataset the dataset to search for nearest neighbors of the key
    * @param key Feature vector representing the item to search for
    * @param numNearestNeighbors The maximum number of nearest neighbors
-   * @param singleProbing True for using Single Probing; false for multiple probing
+   * @param singleProbe True for using single-probe; false for multi-probe
    * @param distCol Output column for storing the distance between each result row and the key
    * @return A dataset containing at most k items closest to the key. A distCol is added to show
    *         the distance between each row and the key.
@@ -121,9 +128,10 @@
       dataset: Dataset[_],
       key: Vector,
       numNearestNeighbors: Int,
-      singleProbing: Boolean,
+      singleProbe: Boolean,
       distCol: String): Dataset[_] = {
     require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
+    checkNearestNeighbor(singleProbe)
     // Get Hash Value of the key
     val keyHash = hashFunction(key)
     val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
@@ -136,7 +144,7 @@
     val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
     val hashDistCol = hashDistUDF(col($(outputCol)))

-    val modelSubset = if (singleProbing) {
+    val modelSubset = if (singleProbe) {
       modelDataset.filter(hashDistCol === 0.0)
     } else {
       // Compute threshold to get exact k elements.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
index e7ba6375168e..6b99fbdb217a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
@@ -83,6 +83,14 @@ class MinHashModel private[ml] (
     }
   }

+  @Since("2.1.0")
+  override protected[this] def checkNearestNeighbor(singleProbe: Boolean) = {
+    if (!singleProbe) {
+      log.warn("Multi-probe for MinHash will run a brute-force nearest neighbor search " +
+        "when there aren't enough candidates.")
+    }
+  }
+
   @Since("2.1.0")
   override def copy(extra: ParamMap): this.type = defaultCopy(extra)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 5c025546f332..f7ded60fb654 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -83,6 +83,7 @@ private[ml] object LSHTest {
   * @param dataset the dataset to look for the key
   * @param key The key to hash for the item
   * @param k The maximum number of items closest to the key
+   * @param singleProbe True for using single-probe; false for multi-probe
   * @tparam T The class type of lsh
   * @return A tuple of two doubles, representing precision and recall rate
   */
@@ -91,7 +92,7 @@
       dataset: Dataset[_],
       key: Vector,
       k: Int,
-      singleProbing: Boolean): (Double, Double) = {
+      singleProbe: Boolean): (Double, Double) = {
     val model = lsh.fit(dataset)

     // Compute expected
     val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
     val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)

     // Compute actual
-    val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol")
+    val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol")

     assert(actual.schema.sameType(model
       .transformSchema(dataset.schema)
       .add("distCol", DataTypes.DoubleType))
     )

-    if (!singleProbing) {
+    if (!singleProbe) {
       assert(actual.count() == k)
     }

From c8243c7def8c270072edd5889cea7fd02677b44f Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Wed, 9 Nov 2016 15:11:20 -0800
Subject: [PATCH 4/5] (1) Fix documentation as CR suggested (2) Fix typo in unit test

---
 .../src/main/scala/org/apache/spark/ml/feature/MinHash.scala  | 4 +++-
 .../test/scala/org/apache/spark/ml/feature/MinHashSuite.scala | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
index 82ac9b0888b2..8b320c5bbb77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
@@ -32,7 +32,9 @@ import org.apache.spark.sql.types.StructType
  * :: Experimental ::
  *
  * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
- * a perfect hash.
+ * a perfect hash function for a specific set `S` with cardinality equal to half of `numEntries`:
+ *    `h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries`
+ *
  * @param numEntries The number of entries of the hash functions.
  * @param randCoefficients An array of random coefficients, each used by one hash function.
  */

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
index cf6290b315ab..d05f693cc961 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
@@ -99,7 +99,7 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
       (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))

     val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
-      singleProbing = true)
+      singleProbe = true)
     assert(precision >= 0.7)
     assert(recall >= 0.7)
   }

From 6aac8b343c5ea3a91b8517a2d3f47ed055ece9ad Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Wed, 9 Nov 2016 15:22:27 -0800
Subject: [PATCH 5/5] Fix typo in unit test

---
 .../org/apache/spark/ml/feature/RandomProjectionSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
index cd82ee2117a0..07f95527fcfd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
@@ -138,7 +138,7 @@ class RandomProjectionSuite
       .setSeed(12345)

     val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
-      singleProbing = true)
+      singleProbe = true)
     assert(precision >= 0.6)
     assert(recall >= 0.6)
   }
@@ -154,7 +154,7 @@ class RandomProjectionSuite
       .setSeed(12345)

     val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
-      singleProbing = false)
+      singleProbe = false)
     assert(precision >= 0.7)
     assert(recall >= 0.7)
   }
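
Note on PATCH 1/5: the old hashDistance took min(|a - b|) over the signature entries, treating MinHash values as coordinates; hash buckets are categorical, so the only meaningful comparison is equality, and the distance should be 0.0 exactly when at least one hash function collides. That is what lets the single-probe filter hashDistCol === 0.0 select the candidate set. A minimal plain-Scala sketch of that logic, outside the Spark codebase (the object and method names here are illustrative, not Spark API):

    object BinaryHashDistanceSketch {
      // 0.0 as soon as any hash function sends both signatures to the same
      // bucket, 1.0 otherwise -- mirroring the new MinHashModel.hashDistance.
      def binaryHashDistance(x: Array[Double], y: Array[Double]): Double =
        if (x.zip(y).exists { case (a, b) => a == b }) 0.0 else 1.0

      def main(args: Array[String]): Unit = {
        // Same cases as the updated MinHashSuite assertions.
        assert(binaryHashDistance(Array(1.0, 2.0), Array(3.0, 4.0)) == 1.0)
        assert(binaryHashDistance(Array(1.0, 2.0), Array(3.0, 2.0)) == 0.0)
      }
    }

Run against the same inputs as the updated unit test, it reproduces the expected 1.0 and 0.0.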
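Note on PATCH 3/5: the warning exists because single-probe and multi-probe pay different costs, as the updated doc describes. The collection-level sketch below mirrors the shape of the candidate selection in approxNearestNeighbors; the sort-based threshold stands in for however Spark actually computes the "threshold to get exact k elements", so treat it, like every name in the snippet, as an assumption:

    object ProbeSketch {
      // Candidate selection from approxNearestNeighbors, over a plain collection.
      def selectCandidates(hashDists: Seq[Double], k: Int, singleProbe: Boolean): Seq[Double] = {
        require(k > 0, "The number of nearest neighbors cannot be less than 1")
        if (singleProbe) {
          // Fast path: keep only rows whose signature collides with the key.
          // May return fewer than k candidates.
          hashDists.filter(_ == 0.0)
        } else {
          // Multi-probe: choose the smallest distance threshold that admits at
          // least k rows, then keep every row at or below it.
          val threshold = hashDists.sorted.apply(math.min(k, hashDists.size) - 1)
          hashDists.filter(_ <= threshold)
        }
      }

      def main(args: Array[String]): Unit = {
        val dists = Seq(0.0, 1.0, 1.0, 0.0, 1.0)
        println(selectCandidates(dists, 3, singleProbe = true))  // List(0.0, 0.0)
        println(selectCandidates(dists, 3, singleProbe = false)) // all five rows here
      }
    }

This also explains the new log.warn: MinHash's binary distance takes only the values 0.0 and 1.0, so when the colliding rows number fewer than k, the first threshold that admits k rows admits the whole dataset, which amounts to a brute-force nearest neighbor search.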
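Note on PATCH 4/5: the restored hash family is h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries, applied to each non-zero index x of the input vector, with the signature entry being the minimum over those indices. The sketch below exercises the formula standalone; the prime, the coefficient sampling, the 1-based index shift (which keeps x in Z_prime^*, per the doc removed in PATCH 2/5), and all names are illustrative assumptions rather than Spark's internal constants:

    import scala.util.Random

    object MinHashFormulaSketch {
      val prime = 2147483647L // 2^31 - 1; the choice of prime here is an assumption
      val numEntries = 10

      // One signature entry: the minimum of h_i over the indices of the
      // non-zero features, shifted by 1 so that x is never zero.
      def minHash(activeIndices: Seq[Int], k: Long): Long =
        activeIndices.map(x => ((1L + x) * k % prime) % numEntries).min

      def main(args: Array[String]): Unit = {
        val rng = new Random(42)
        // One random coefficient k_i per hash function, drawn from Z_prime^*.
        val coefficients = Seq.fill(3)(1L + rng.nextInt((prime - 1).toInt))
        val signature = coefficients.map(k => minHash(Seq(2, 3, 5, 7), k))
        println(s"MinHash signature: $signature")
      }
    }

Two inputs that share a non-zero index have a chance of agreeing in a signature entry, which is exactly the collision the binary hash distance of PATCH 1/5 tests for.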