Skip to content

Commit 89ea0a8

Browse files
committed
Hack ALSSuite to support NNLS testing.
1 parent f5dbf4d commit 89ea0a8

File tree

1 file changed

+39
-7
lines changed
  • mllib/src/test/scala/org/apache/spark/mllib/recommendation

1 file changed

+39
-7
lines changed

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,18 @@ object ALSSuite {
4848
features: Int,
4949
samplingRate: Double,
5050
implicitPrefs: Boolean = false,
51-
negativeWeights: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
51+
negativeWeights: Boolean = false,
52+
negativeFactors: Boolean = true): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
5253
val rand = new Random(42)
5354

5455
// Create a random matrix with uniform values from -1 to 1 (0 to 1 if negativeFactors is false)
55-
def randomMatrix(m: Int, n: Int) =
56-
new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
56+
def randomMatrix(m: Int, n: Int) = {
57+
if (negativeFactors) {
58+
new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
59+
} else {
60+
new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble()): _*)
61+
}
62+
}
5763

5864
val userMatrix = randomMatrix(users, features)
5965
val productMatrix = randomMatrix(features, products)
@@ -128,6 +134,27 @@ class ALSSuite extends FunSuite with LocalSparkContext {
128134
assert(u11 != u2)
129135
}
130136

137+
test("negative ids") {
138+
val data = ALSSuite.generateRatings(50, 50, 2, 0.7, false, false)
139+
val ratings = sc.parallelize(data._1.map { case Rating(u,p,r) => Rating(u-25,p-25,r) })
140+
val correct = data._2
141+
val model = ALS.train(ratings, 5, 15)
142+
143+
val pairs = Array.tabulate(50, 50)((u,p) => (u-25,p-25)).flatten
144+
val ans = model.predict(sc.parallelize(pairs)).collect
145+
ans.foreach { r =>
146+
val u = r.user + 25
147+
val p = r.product + 25
148+
val v = r.rating
149+
val error = v - correct.get(u,p)
150+
assert(math.abs(error) < 0.4)
151+
}
152+
}
153+
154+
test("NNALS, rank 2") {
155+
testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false)
156+
}
157+
131158
/**
132159
* Test if we can correctly factorize R = U * P where U and P are of known rank.
133160
*
@@ -140,16 +167,21 @@ class ALSSuite extends FunSuite with LocalSparkContext {
140167
* @param implicitPrefs flag to test implicit feedback
141168
* @param bulkPredict flag to test bulk prediction
142169
* @param negativeWeights whether the generated data can contain negative values
170+
* @param numBlocks number of blocks to partition users and products into
171+
* @param negativeFactors whether the generated user/product factors can have negative entries
143172
*/
144173
def testALS(users: Int, products: Int, features: Int, iterations: Int,
145174
samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
146-
bulkPredict: Boolean = false, negativeWeights: Boolean = false)
175+
bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1,
176+
negativeFactors: Boolean = true)
147177
{
148178
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
149-
features, samplingRate, implicitPrefs, negativeWeights)
179+
features, samplingRate, implicitPrefs, negativeWeights, negativeFactors)
150180
val model = implicitPrefs match {
151-
case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
152-
case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
181+
case false => ALS.train(sc.parallelize(sampledRatings), features, iterations, 0.01,
182+
numBlocks, 0L, !negativeFactors)
183+
case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations, 0.01,
184+
numBlocks, 1.0, 0L, !negativeFactors)
153185
}
154186

155187
val predictedU = new DoubleMatrix(users, features)

0 commit comments

Comments
 (0)