Skip to content

Commit cdf315b

Browse files
Yun Nijkbradley
authored andcommitted
[SPARK-18408][ML] API Improvements for LSH
## What changes were proposed in this pull request? (1) Change output schema to `Array of Vector` instead of `Vectors` (2) Use `numHashTables` as the dimension of Array (3) Rename `RandomProjection` to `BucketedRandomProjectionLSH`, `MinHash` to `MinHashLSH` (4) Make `randUnitVectors/randCoefficients` private (5) Make Multi-Probe NN Search and `hashDistance` private for future discussion Saved for future PRs: (1) AND-amplification and `numHashFunctions` as the dimension of Vector are saved for a future PR. (2) `hashDistance` and MultiProbe NN Search needs more discussion. The current implementation is just a backward compatible one. ## How was this patch tested? Related unit tests are modified to make sure the performance of LSH are ensured, and the outputs of the APIs meets expectation. Author: Yun Ni <yunn@uber.com> Author: Yunni <Euler57721@gmail.com> Closes #15874 from Yunni/SPARK-18408-yunn-api-improvements. (cherry picked from commit 05f7c6f) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
1 parent 75d73d1 commit cdf315b

File tree

6 files changed

+306
-221
lines changed

6 files changed

+306
-221
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala renamed to mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.StructType
3434
/**
3535
* :: Experimental ::
3636
*
37-
* Params for [[RandomProjection]].
37+
* Params for [[BucketedRandomProjectionLSH]].
3838
*/
39-
private[ml] trait RandomProjectionParams extends Params {
39+
private[ml] trait BucketedRandomProjectionLSHParams extends Params {
4040

4141
/**
4242
* The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
@@ -58,8 +58,8 @@ private[ml] trait RandomProjectionParams extends Params {
5858
/**
5959
* :: Experimental ::
6060
*
61-
* Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors
62-
* are normalized to be unit vectors and each vector is used in a hash function:
61+
* Model produced by [[BucketedRandomProjectionLSH]], where multiple random vectors are stored. The
62+
* vectors are normalized to be unit vectors and each vector is used in a hash function:
6363
* `h_i(x) = floor(r_i.dot(x) / bucketLength)`
6464
* where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
6565
* vectors) / bucketLength`.
@@ -68,18 +68,19 @@ private[ml] trait RandomProjectionParams extends Params {
6868
*/
6969
@Experimental
7070
@Since("2.1.0")
71-
class RandomProjectionModel private[ml] (
71+
class BucketedRandomProjectionLSHModel private[ml](
7272
override val uid: String,
73-
@Since("2.1.0") val randUnitVectors: Array[Vector])
74-
extends LSHModel[RandomProjectionModel] with RandomProjectionParams {
73+
private[ml] val randUnitVectors: Array[Vector])
74+
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
7575

7676
@Since("2.1.0")
77-
override protected[ml] val hashFunction: (Vector) => Vector = {
77+
override protected[ml] val hashFunction: Vector => Array[Vector] = {
7878
key: Vector => {
7979
val hashValues: Array[Double] = randUnitVectors.map({
8080
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
8181
})
82-
Vectors.dense(hashValues)
82+
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
83+
hashValues.map(Vectors.dense(_))
8384
}
8485
}
8586

@@ -89,27 +90,29 @@ class RandomProjectionModel private[ml] (
8990
}
9091

9192
@Since("2.1.0")
92-
override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
93+
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
9394
// Since it's generated by hashing, it will be a pair of dense vectors.
94-
x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
95+
x.zip(y).map(vectorPair => Vectors.sqdist(vectorPair._1, vectorPair._2)).min
9596
}
9697

9798
@Since("2.1.0")
9899
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
99100

100101
@Since("2.1.0")
101-
override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this)
102+
override def write: MLWriter = {
103+
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
104+
}
102105
}
103106

104107
/**
105108
* :: Experimental ::
106109
*
107-
* This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean
108-
* distance metrics.
110+
* This [[BucketedRandomProjectionLSH]] implements Locality Sensitive Hashing functions for
111+
* Euclidean distance metrics.
109112
*
110113
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
111-
* distance space. The output will be vectors of configurable dimension. Hash value in the same
112-
* dimension is calculated by the same hash function.
114+
* distance space. The output will be vectors of configurable dimension. Hash values in the
115+
* same dimension are calculated by the same hash function.
113116
*
114117
* References:
115118
*
@@ -121,8 +124,9 @@ class RandomProjectionModel private[ml] (
121124
*/
122125
@Experimental
123126
@Since("2.1.0")
124-
class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel]
125-
with RandomProjectionParams with HasSeed {
127+
class BucketedRandomProjectionLSH(override val uid: String)
128+
extends LSH[BucketedRandomProjectionLSHModel]
129+
with BucketedRandomProjectionLSHParams with HasSeed {
126130

127131
@Since("2.1.0")
128132
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -131,11 +135,11 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
131135
override def setOutputCol(value: String): this.type = super.setOutputCol(value)
132136

133137
@Since("2.1.0")
134-
override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
138+
override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)
135139

136140
@Since("2.1.0")
137141
def this() = {
138-
this(Identifiable.randomUID("random projection"))
142+
this(Identifiable.randomUID("brp-lsh"))
139143
}
140144

141145
/** @group setParam */
@@ -147,15 +151,16 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
147151
def setSeed(value: Long): this.type = set(seed, value)
148152

149153
@Since("2.1.0")
150-
override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = {
154+
override protected[this] def createRawLSHModel(
155+
inputDim: Int): BucketedRandomProjectionLSHModel = {
151156
val rand = new Random($(seed))
152157
val randUnitVectors: Array[Vector] = {
153-
Array.fill($(outputDim)) {
158+
Array.fill($(numHashTables)) {
154159
val randArray = Array.fill(inputDim)(rand.nextGaussian())
155160
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
156161
}
157162
}
158-
new RandomProjectionModel(uid, randUnitVectors)
163+
new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
159164
}
160165

161166
@Since("2.1.0")
@@ -169,23 +174,25 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
169174
}
170175

171176
@Since("2.1.0")
172-
object RandomProjection extends DefaultParamsReadable[RandomProjection] {
177+
object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomProjectionLSH] {
173178

174179
@Since("2.1.0")
175-
override def load(path: String): RandomProjection = super.load(path)
180+
override def load(path: String): BucketedRandomProjectionLSH = super.load(path)
176181
}
177182

178183
@Since("2.1.0")
179-
object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
184+
object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {
180185

181186
@Since("2.1.0")
182-
override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader
187+
override def read: MLReader[BucketedRandomProjectionLSHModel] = {
188+
new BucketedRandomProjectionLSHModelReader
189+
}
183190

184191
@Since("2.1.0")
185-
override def load(path: String): RandomProjectionModel = super.load(path)
192+
override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)
186193

187-
private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel)
188-
extends MLWriter {
194+
private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
195+
instance: BucketedRandomProjectionLSHModel) extends MLWriter {
189196

190197
// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
191198
private case class Data(randUnitVectors: Matrix)
@@ -203,20 +210,22 @@ object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
203210
}
204211
}
205212

206-
private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] {
213+
private class BucketedRandomProjectionLSHModelReader
214+
extends MLReader[BucketedRandomProjectionLSHModel] {
207215

208216
/** Checked against metadata when loading model */
209-
private val className = classOf[RandomProjectionModel].getName
217+
private val className = classOf[BucketedRandomProjectionLSHModel].getName
210218

211-
override def load(path: String): RandomProjectionModel = {
219+
override def load(path: String): BucketedRandomProjectionLSHModel = {
212220
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
213221

214222
val dataPath = new Path(path, "data").toString
215223
val data = sparkSession.read.parquet(dataPath)
216224
val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
217225
.select("randUnitVectors")
218226
.head()
219-
val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
227+
val model = new BucketedRandomProjectionLSHModel(metadata.uid,
228+
randUnitVectors.rowIter.toArray)
220229

221230
DefaultParamsReader.getAndSetParams(model, metadata)
222231
model

0 commit comments

Comments
 (0)