@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.StructType
3434/**
3535 * :: Experimental ::
3636 *
37- * Params for [[RandomProjection ]].
37+ * Params for [[BucketedRandomProjectionLSH ]].
3838 */
39- private [ml] trait RandomProjectionParams extends Params {
39+ private [ml] trait BucketedRandomProjectionLSHParams extends Params {
4040
4141 /**
4242 * The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
@@ -58,8 +58,8 @@ private[ml] trait RandomProjectionParams extends Params {
5858/**
5959 * :: Experimental ::
6060 *
61- * Model produced by [[RandomProjection ]], where multiple random vectors are stored. The vectors
62- * are normalized to be unit vectors and each vector is used in a hash function:
61+ * Model produced by [[BucketedRandomProjectionLSH ]], where multiple random vectors are stored. The
62+ * vectors are normalized to be unit vectors and each vector is used in a hash function:
6363 * `h_i(x) = floor(r_i.dot(x) / bucketLength)`
6464 * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
6565 * vectors) / bucketLength`.
@@ -68,18 +68,19 @@ private[ml] trait RandomProjectionParams extends Params {
6868 */
6969@ Experimental
7070@ Since (" 2.1.0" )
71- class RandomProjectionModel private [ml] (
71+ class BucketedRandomProjectionLSHModel private [ml](
7272 override val uid : String ,
73- @ Since ( " 2.1.0 " ) val randUnitVectors : Array [Vector ])
74- extends LSHModel [RandomProjectionModel ] with RandomProjectionParams {
73+ private [ml] val randUnitVectors : Array [Vector ])
74+ extends LSHModel [BucketedRandomProjectionLSHModel ] with BucketedRandomProjectionLSHParams {
7575
/**
 * Hash function applied to one input vector. It yields one single-element vector per
 * entry of `randUnitVectors`, holding `floor(dot(key, r_i) / bucketLength)`.
 */
@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = { (point: Vector) =>
  // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
  randUnitVectors.map { direction =>
    // Project the point onto the random direction, then take the index of the
    // bucket (of width `bucketLength`) that the projection falls into.
    val bucketIndex = Math.floor(BLAS.dot(point, direction) / $(bucketLength))
    Vectors.dense(bucketIndex)
  }
}
8586
@@ -89,27 +90,29 @@ class RandomProjectionModel private[ml] (
8990 }
9091
/**
 * Distance between two hashed outputs: the smallest `Vectors.sqdist` over the hash
 * vectors of `x` and `y` paired position by position.
 *
 * @param x hash vectors of the first point (presumably an output of `hashFunction`)
 * @param y hash vectors of the second point (presumably an output of `hashFunction`)
 * @return the minimum squared distance among the zipped pairs
 */
@Since("2.1.0")
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
  // `zip` stops at the shorter of the two sequences.
  val perPairDistances = x.zip(y).map { case (left, right) => Vectors.sqdist(left, right) }
  // `min` throws on an empty sequence, so both inputs must hold at least one vector.
  perPairDistances.min
}
9697
/** Copies this model by delegating to `defaultCopy`, passing `extra` through. */
@Since("2.1.0")
override def copy(extra: ParamMap): this.type = {
  defaultCopy(extra)
}
99100
/** Returns a new writer constructed around this model instance. */
@Since("2.1.0")
override def write: MLWriter =
  new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
102105}
103106
104107/**
105108 * :: Experimental ::
106109 *
107- * This [[RandomProjection ]] implements Locality Sensitive Hashing functions for Euclidean
108- * distance metrics.
110+ * This [[BucketedRandomProjectionLSH ]] implements Locality Sensitive Hashing functions for
111+ * Euclidean distance metrics.
109112 *
110113 * The input is dense or sparse vectors, each of which represents a point in the Euclidean
111- * distance space. The output will be vectors of configurable dimension. Hash value in the same
112- * dimension is calculated by the same hash function.
114+ * distance space. The output will be vectors of configurable dimension. Hash values in the
115+ * same dimension are calculated by the same hash function.
113116 *
114117 * References:
115118 *
@@ -121,8 +124,9 @@ class RandomProjectionModel private[ml] (
121124 */
122125@ Experimental
123126@ Since (" 2.1.0" )
124- class RandomProjection (override val uid : String ) extends LSH [RandomProjectionModel ]
125- with RandomProjectionParams with HasSeed {
127+ class BucketedRandomProjectionLSH (override val uid : String )
128+ extends LSH [BucketedRandomProjectionLSHModel ]
129+ with BucketedRandomProjectionLSHParams with HasSeed {
126130
/** Sets the input column name by forwarding `value` to the inherited setter. */
@Since("2.1.0")
override def setInputCol(value: String): this.type = {
  super.setInputCol(value)
}
@@ -131,11 +135,11 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
131135 override def setOutputCol (value : String ): this .type = super .setOutputCol(value)
132136
/** Sets the number of hash tables by forwarding `value` to the inherited setter. */
@Since("2.1.0")
override def setNumHashTables(value: Int): this.type = {
  super.setNumHashTables(value)
}
135139
/** Creates an instance whose uid is generated by `Identifiable.randomUID` with prefix "brp-lsh". */
@Since("2.1.0")
def this() = this(Identifiable.randomUID("brp-lsh"))
140144
141145 /** @group setParam */
@@ -147,15 +151,16 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
147151 def setSeed (value : Long ): this .type = set(seed, value)
148152
/**
 * Builds a model holding `numHashTables` random vectors of length `inputDim`.
 * Each vector is filled with Gaussian samples from a generator seeded with `seed`
 * and then passed through `normalize`.
 *
 * @param inputDim length of each generated random vector
 * @return a new model sharing this estimator's `uid`
 */
@Since("2.1.0")
override protected[this] def createRawLSHModel(
    inputDim: Int): BucketedRandomProjectionLSHModel = {
  val gaussianSource = new Random($(seed))
  val tableCount = $(numHashTables)

  // Draws `inputDim` Gaussian components and normalizes the resulting vector.
  def nextUnitVector(): Vector = {
    val components = Array.fill(inputDim)(gaussianSource.nextGaussian())
    Vectors.fromBreeze(normalize(breeze.linalg.Vector(components)))
  }

  // `Array.fill` evaluates its element expression once per slot, in order, so the
  // vectors consume the generator's samples sequentially.
  val randUnitVectors: Array[Vector] = Array.fill(tableCount)(nextUnitVector())
  new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
}
160165
161166 @ Since (" 2.1.0" )
@@ -169,23 +174,25 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
169174}
170175
/** Companion object; loading is inherited from `DefaultParamsReadable`. */
@Since("2.1.0")
object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomProjectionLSH] {

  /** Loads an instance from `path` by delegating to the inherited `load`. */
  @Since("2.1.0")
  override def load(path: String): BucketedRandomProjectionLSH = {
    super.load(path)
  }
}
177182
178183@ Since (" 2.1.0" )
179- object RandomProjectionModel extends MLReadable [RandomProjectionModel ] {
184+ object BucketedRandomProjectionLSHModel extends MLReadable [BucketedRandomProjectionLSHModel ] {
180185
/** Returns a new `BucketedRandomProjectionLSHModelReader`. */
@Since("2.1.0")
override def read: MLReader[BucketedRandomProjectionLSHModel] =
  new BucketedRandomProjectionLSHModelReader
183190
/** Loads a model from `path` by delegating to the inherited `load`. */
@Since("2.1.0")
override def load(path: String): BucketedRandomProjectionLSHModel = {
  super.load(path)
}
186193
187- private [RandomProjectionModel ] class RandomProjectionModelWriter ( instance : RandomProjectionModel )
188- extends MLWriter {
194+ private [BucketedRandomProjectionLSHModel ] class BucketedRandomProjectionLSHModelWriter (
195+ instance : BucketedRandomProjectionLSHModel ) extends MLWriter {
189196
190197 // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
191198 private case class Data (randUnitVectors : Matrix )
@@ -203,20 +210,22 @@ object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
203210 }
204211 }
205212
206- private class RandomProjectionModelReader extends MLReader [RandomProjectionModel ] {
213+ private class BucketedRandomProjectionLSHModelReader
214+ extends MLReader [BucketedRandomProjectionLSHModel ] {
207215
208216 /** Checked against metadata when loading model */
209- private val className = classOf [RandomProjectionModel ].getName
217+ private val className = classOf [BucketedRandomProjectionLSHModel ].getName
210218
211- override def load (path : String ): RandomProjectionModel = {
219+ override def load (path : String ): BucketedRandomProjectionLSHModel = {
212220 val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
213221
214222 val dataPath = new Path (path, " data" ).toString
215223 val data = sparkSession.read.parquet(dataPath)
216224 val Row (randUnitVectors : Matrix ) = MLUtils .convertMatrixColumnsToML(data, " randUnitVectors" )
217225 .select(" randUnitVectors" )
218226 .head()
219- val model = new RandomProjectionModel (metadata.uid, randUnitVectors.rowIter.toArray)
227+ val model = new BucketedRandomProjectionLSHModel (metadata.uid,
228+ randUnitVectors.rowIter.toArray)
220229
221230 DefaultParamsReader .getAndSetParams(model, metadata)
222231 model
0 commit comments