Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,27 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
validateAndTransformSchema(schema)
}

/**
 * Hook for checking the prerequisites of a nearest neighbor search. The default
 * implementation does nothing; subclasses may override it.
 *
 * @param singleProbe True for using single-probe; false for multi-probe
 */
protected[this] def checkNearestNeighbor(singleProbe: Boolean): Unit = {}

/**
* Given a large dataset and an item, approximately find at most k items which have the closest
* distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
* the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
* transformed data when necessary.
*
* This method implements two ways of fetching k nearest neighbors:
* - Single Probing: Fast, return at most k elements (Probing only one buckets)
* - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
* - Single-probe: Fast, returns at most k elements (probing only one bucket)
* - Multi-probe: Slow, returns exactly k elements (probing multiple buckets close to the key)
*
* @param dataset the dataset to search for nearest neighbors of the key
* @param key Feature vector representing the item to search for
* @param numNearestNeighbors The maximum number of nearest neighbors
* @param singleProbing True for using Single Probing; false for multiple probing
* @param singleProbe True for using single-probe; false for multi-probe
* @param distCol Output column for storing the distance between each result row and the key
* @return A dataset containing at most k items closest to the key. A distCol is added to show
* the distance between each row and the key.
Expand All @@ -121,9 +128,10 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
dataset: Dataset[_],
key: Vector,
numNearestNeighbors: Int,
singleProbing: Boolean,
singleProbe: Boolean,
distCol: String): Dataset[_] = {
require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
checkNearestNeighbor(singleProbe)
// Get Hash Value of the key
val keyHash = hashFunction(key)
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
Expand All @@ -136,7 +144,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
val hashDistCol = hashDistUDF(col($(outputCol)))

val modelSubset = if (singleProbing) {
val modelSubset = if (singleProbe) {
modelDataset.filter(hashDistCol === 0.0)
} else {
// Compute threshold to get exact k elements.
Expand Down
22 changes: 15 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ import org.apache.spark.sql.types.StructType
* :: Experimental ::
*
* Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
* a perfect hash function:
* `h_i(x) = (x * k_i mod prime) mod numEntries`
* where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
*
* Reference:
* [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]]
* a perfect hash function for a specific set `S` with cardinality equal to half of `numEntries`:
* `h_i(x) = ((x * k_i) mod prime) mod numEntries`
*
* @param numEntries The number of entries of the hash functions.
* @param randCoefficients An array of random coefficients, each used by one hash function.
Expand Down Expand Up @@ -76,7 +72,19 @@ class MinHashModel private[ml] (
@Since("2.1.0")
override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
if (x.toDense.values.zip(y.toDense.values).exists(pair => pair._1 == pair._2)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why just 0 and 1? I think the more pairs of values that are the same, the closer the two vectors are, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See discussion above :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I agree more with the comment from @jkbradley at #15800 (comment), if I understand some of the terms here correctly.

Is the indicator meaning a matching hashing value between two vectors from one hashing function, i.e., h_i?
If this understanding is correct, I think averaging indicators should be the right way to compute MinHash's hash distance.

0
} else {
1
}
}

/**
 * Check prerequisite for nearest neighbor: logs a warning when multi-probe is requested.
 *
 * @param singleProbe True for using single-probe; false for multi-probe
 */
@Since("2.1.0")
override protected[this] def checkNearestNeighbor(singleProbe: Boolean): Unit = {
  // Single-probe needs no notice; only multi-probe triggers the warning below.
  if (!singleProbe) {
    log.warn("Multi-probe for MinHash will run brute force nearest neighbor when there " +
      "aren't enough candidates.")
  }
}

@Since("2.1.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private[ml] object LSHTest {
* @param dataset the dataset to look for the key
* @param key The key to hash for the item
* @param k The maximum number of items closest to the key
* @param singleProbe True for using single-probe; false for multi-probe
* @tparam T The class type of lsh
* @return A tuple of two doubles, representing precision and recall rate
*/
Expand All @@ -91,22 +92,22 @@ private[ml] object LSHTest {
dataset: Dataset[_],
key: Vector,
k: Int,
singleProbing: Boolean): (Double, Double) = {
singleProbe: Boolean): (Double, Double) = {
val model = lsh.fit(dataset)

// Compute expected
val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)

// Compute actual
val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol")
val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol")

assert(actual.schema.sameType(model
.transformSchema(dataset.schema)
.add("distCol", DataTypes.DoubleType))
)

if (!singleProbing) {
if (!singleProbe) {
assert(actual.count() == k)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
val keyDist = model.keyDistance(v1, v2)
val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
val hashDist1 = model.hashDistance(Vectors.dense(1, 2), Vectors.dense(3, 4))
val hashDist2 = model.hashDistance(Vectors.dense(1, 2), Vectors.dense(3, 2))
assert(keyDist === 0.5)
assert(hashDist === 3)
assert(hashDist1 === 1.0)
assert(hashDist2 === 0.0)
}

test("MinHash: test of LSH property") {
Expand All @@ -97,7 +99,7 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
(0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))

val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
singleProbing = true)
singleProbe = true)
assert(precision >= 0.7)
assert(recall >= 0.7)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class RandomProjectionSuite
.setSeed(12345)

val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
singleProbing = true)
singleProbe = true)
assert(precision >= 0.6)
assert(recall >= 0.6)
}
Expand All @@ -154,7 +154,7 @@ class RandomProjectionSuite
.setSeed(12345)

val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
singleProbing = false)
singleProbe = false)
assert(precision >= 0.7)
assert(recall >= 0.7)
}
Expand Down