Skip to content

Commit b285106

Browse files
committed
Clean up NNLS test cases.
1 parent 9c820b6 commit b285106

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,40 +20,55 @@ package org.apache.spark.mllib.optimization
2020
import scala.util.Random
2121

2222
import org.scalatest.FunSuite
23-
import org.scalatest.matchers.ShouldMatchers
2423

25-
import org.apache.spark.mllib.util.LocalSparkContext
24+
import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}
2625

27-
import org.jblas.DoubleMatrix
28-
import org.jblas.SimpleBlas
29-
30-
class NNLSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
26+
class NNLSSuite extends FunSuite {
3127
test("NNLSbyPCG: exact solution case") {
32-
val A = new DoubleMatrix(20, 20)
33-
val b = new DoubleMatrix(20, 1)
28+
val n = 20
29+
val A = new DoubleMatrix(n, n)
30+
val b = new DoubleMatrix(n, 1)
3431
val rand = new Random(12345)
35-
for (i <- 0 until 20; j <- 0 until 20) {
32+
for (i <- 0 until n; j <- 0 until n) {
3633
val aij = rand.nextDouble()
3734
A.put(i, j, aij)
3835
b.put(i, b.get(i, 0) + aij)
3936
}
4037

41-
val ata = new DoubleMatrix(20, 20)
42-
val atb = new DoubleMatrix(20, 1)
43-
for (i <- 0 until 20; j <- 0 until 20; k <- 0 until 20) {
44-
ata.put(i, j, ata.get(i, j) + A.get(k, i) * A.get(k, j))
45-
}
46-
for (i <- 0 until 20; j <- 0 until 20) {
47-
atb.put(i, atb.get(i, 0) + A.get(j, i) * b.get(j))
48-
}
38+
val ata = new DoubleMatrix(n, n)
39+
val atb = new DoubleMatrix(n, 1)
40+
41+
NativeBlas.dgemm('T', 'N', n, n, n, 1.0, A.data, 0, n, A.data, 0, n, 0.0, ata.data, 0, n)
42+
NativeBlas.dgemv('T', n, n, 1.0, A.data, 0, n, b.data, 0, 1, 0.0, atb.data, 0, 1)
4943

5044
val x = NNLSbyPCG.solve(ata, atb, true)
51-
assert(x.length == 20)
45+
assert(x.length == n)
5246
var error = 0.0
53-
for (i <- 0 until 20) {
47+
for (i <- 0 until n) {
5448
error = error + (x(i) - 1) * (x(i) - 1)
5549
assert(Math.abs(x(i) - 1) < 1e-3)
5650
}
5751
assert(error < 1e-2)
5852
}
53+
54+
test("NNLSbyPCG: nonnegativity constraint active") {
55+
val n = 5
56+
val M = Array(
57+
Array( 4.377, -3.531, -1.306, -0.139, 3.418, -1.632),
58+
Array(-3.531, 4.344, 0.934, 0.305, -2.140, 2.115),
59+
Array(-1.306, 0.934, 2.644, -0.203, -0.170, 1.094),
60+
Array(-0.139, 0.305, -0.203, 5.883, 1.428, -1.025),
61+
Array( 3.418, -2.140, -0.170, 1.428, 4.684, -0.636))
62+
val ata = new DoubleMatrix(5, 5)
63+
val atb = new DoubleMatrix(5, 1)
64+
for (i <- 0 until 5; j <- 0 until 5) ata.put(i, j, M(i)(j))
65+
for (i <- 0 until 5) atb.put(i, M(i)(5))
66+
67+
val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)
68+
69+
val x = NNLSbyPCG.solve(ata, atb, true)
70+
for (i <- 0 until 5) {
71+
assert(Math.abs(x(i) - goodx(i)) < 1e-3)
72+
}
73+
}
5974
}

0 commit comments

Comments
 (0)