@@ -20,40 +20,55 @@ package org.apache.spark.mllib.optimization
2020import scala .util .Random
2121
2222import org .scalatest .FunSuite
23- import org .scalatest .matchers .ShouldMatchers
2423
25- import org .apache . spark . mllib . util . LocalSparkContext
24+ import org .jblas .{ DoubleMatrix , SimpleBlas , NativeBlas }
2625
27- import org .jblas .DoubleMatrix
28- import org .jblas .SimpleBlas
29-
30- class NNLSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
26+ class NNLSSuite extends FunSuite {
3127 test(" NNLSbyPCG: exact solution case" ) {
32- val A = new DoubleMatrix (20 , 20 )
33- val b = new DoubleMatrix (20 , 1 )
28+ val n = 20
29+ val A = new DoubleMatrix (n, n)
30+ val b = new DoubleMatrix (n, 1 )
3431 val rand = new Random (12345 )
35- for (i <- 0 until 20 ; j <- 0 until 20 ) {
32+ for (i <- 0 until n ; j <- 0 until n ) {
3633 val aij = rand.nextDouble()
3734 A .put(i, j, aij)
3835 b.put(i, b.get(i, 0 ) + aij)
3936 }
4037
41- val ata = new DoubleMatrix (20 , 20 )
42- val atb = new DoubleMatrix (20 , 1 )
43- for (i <- 0 until 20 ; j <- 0 until 20 ; k <- 0 until 20 ) {
44- ata.put(i, j, ata.get(i, j) + A .get(k, i) * A .get(k, j))
45- }
46- for (i <- 0 until 20 ; j <- 0 until 20 ) {
47- atb.put(i, atb.get(i, 0 ) + A .get(j, i) * b.get(j))
48- }
38+ val ata = new DoubleMatrix (n, n)
39+ val atb = new DoubleMatrix (n, 1 )
40+
41+ NativeBlas .dgemm('T' , 'N' , n, n, n, 1.0 , A .data, 0 , n, A .data, 0 , n, 0.0 , ata.data, 0 , n)
42+ NativeBlas .dgemv('T' , n, n, 1.0 , A .data, 0 , n, b.data, 0 , 1 , 0.0 , atb.data, 0 , 1 )
4943
5044 val x = NNLSbyPCG .solve(ata, atb, true )
51- assert(x.length == 20 )
45+ assert(x.length == n )
5246 var error = 0.0
53- for (i <- 0 until 20 ) {
47+ for (i <- 0 until n ) {
5448 error = error + (x(i) - 1 ) * (x(i) - 1 )
5549 assert(Math .abs(x(i) - 1 ) < 1e-3 )
5650 }
5751 assert(error < 1e-2 )
5852 }
53+
54+ test(" NNLSbyPCG: nonnegativity constraint active" ) {
55+ val n = 5
56+ val M = Array (
57+ Array ( 4.377 , - 3.531 , - 1.306 , - 0.139 , 3.418 , - 1.632 ),
58+ Array (- 3.531 , 4.344 , 0.934 , 0.305 , - 2.140 , 2.115 ),
59+ Array (- 1.306 , 0.934 , 2.644 , - 0.203 , - 0.170 , 1.094 ),
60+ Array (- 0.139 , 0.305 , - 0.203 , 5.883 , 1.428 , - 1.025 ),
61+ Array ( 3.418 , - 2.140 , - 0.170 , 1.428 , 4.684 , - 0.636 ))
62+ val ata = new DoubleMatrix (5 , 5 )
63+ val atb = new DoubleMatrix (5 , 1 )
64+ for (i <- 0 until 5 ; j <- 0 until 5 ) ata.put(i, j, M (i)(j))
65+ for (i <- 0 until 5 ) atb.put(i, M (i)(5 ))
66+
67+ val goodx = Array (0.13025 , 0.54506 , 0.2874 , 0.0 , 0.028628 )
68+
69+ val x = NNLSbyPCG .solve(ata, atb, true )
70+ for (i <- 0 until 5 ) {
71+ assert(Math .abs(x(i) - goodx(i)) < 1e-3 )
72+ }
73+ }
5974}
0 commit comments