@@ -14,9 +14,17 @@ import breeze.math._
1414 *
1515 * @author dlwh
1616 */
17- class OWLQN [T ](maxIter : Int , m : Int , l1reg : Double = 1.0 , tolerance : Double = 1E-8 )(implicit space : MutableCoordinateField [T , Double ]) extends LBFGS [T ](maxIter, m, tolerance= tolerance) with SerializableLogging {
17+ class OWLQN [T , K ](maxIter : Int , m : Int , l1reg : K => Double , tolerance : Double )(implicit space : MutableEnumeratedCoordinateField [T , K , Double ]) extends LBFGS [T ](maxIter, m, tolerance= tolerance) with SerializableLogging {
18+
/**
 * Auxiliary constructor taking a per-key L1 weight function and no explicit tolerance.
 *
 * Delegates to the primary constructor with a tolerance of 1E-8.
 *
 * @param maxIter passed through unchanged to the primary constructor
 * @param m       passed through unchanged to the primary constructor
 * @param l1reg   function giving the L1 weight for each key
 * @param space   implicit coordinate field for `T` keyed by `K`
 */
def this(
    maxIter: Int,
    m: Int,
    l1reg: K => Double
)(implicit space: MutableEnumeratedCoordinateField[T, K, Double]) =
  this(maxIter, m, l1reg, 1E-8)
20+
/**
 * Auxiliary constructor taking one constant L1 weight applied to every key.
 *
 * The scalar is wrapped in a function that ignores its key argument and the result
 * is handed to the primary constructor together with `tolerance`.
 *
 * NOTE(review): a three-argument call with a `Double` weight also matches the sibling
 * constructor that declares no default, so the default here looks redundant; it is
 * kept because callers may rely on it. Confirm before removing.
 *
 * @param maxIter   passed through unchanged to the primary constructor
 * @param m         passed through unchanged to the primary constructor
 * @param l1reg     constant L1 weight returned for every key
 * @param tolerance passed through unchanged to the primary constructor (default 1E-8)
 * @param space     implicit coordinate field for `T` keyed by `K`
 */
def this(
    maxIter: Int,
    m: Int,
    l1reg: Double,
    tolerance: Double = 1E-8
)(implicit space: MutableEnumeratedCoordinateField[T, K, Double]) =
  this(maxIter, m, (_: K) => l1reg, tolerance)
22+
/**
 * Auxiliary constructor taking one constant L1 weight and no tolerance.
 *
 * The scalar is wrapped in a function that ignores its key argument; the primary
 * constructor is then called with a tolerance of 1E-8.
 *
 * @param maxIter passed through unchanged to the primary constructor
 * @param m       passed through unchanged to the primary constructor
 * @param l1reg   constant L1 weight returned for every key
 * @param space   implicit coordinate field for `T` keyed by `K`
 */
def this(
    maxIter: Int,
    m: Int,
    l1reg: Double
)(implicit space: MutableEnumeratedCoordinateField[T, K, Double]) =
  this(maxIter, m, (_: K) => l1reg, 1E-8)
24+
/**
 * Auxiliary constructor that supplies both the L1 weight and the tolerance.
 *
 * Calls the primary constructor with a weight function returning 1.0 for every key
 * and a tolerance of 1E-8.
 *
 * @param maxIter passed through unchanged to the primary constructor
 * @param m       passed through unchanged to the primary constructor
 * @param space   implicit coordinate field for `T` keyed by `K`
 */
def this(
    maxIter: Int,
    m: Int
)(implicit space: MutableEnumeratedCoordinateField[T, K, Double]) =
  this(maxIter, m, (_: K) => 1.0, 1E-8)
26+
1827 require(m > 0 )
19- require(l1reg >= 0 )
2028
2129 import space ._
2230
@@ -81,18 +89,25 @@ class OWLQN[T](maxIter: Int, m: Int, l1reg: Double=1.0, tolerance: Double = 1E-
8189
/**
 * Adds the per-key L1 regularization contribution to the objective value and to
 * each gradient entry.
 *
 * For every key the weight `l1reg(key)` is looked up and must be non-negative
 * (checked with `require`). A key whose weight is exactly zero keeps its gradient
 * entry as is and adds nothing to the value. For any other key, `|weight * x|` is
 * added to the value and the gradient entry is shifted by the weight: in the
 * direction of `signum(x)` when `x` is non-zero, or by the one-sided rule
 * commented below when `x` is zero.
 *
 * @param newX    point whose coordinates are penalised
 * @param newGrad gradient entries to be adjusted, aligned key-by-key with `newX`
 * @param newVal  objective value to which the penalties are added
 * @return the penalised value paired with the adjusted gradient
 */
override protected def adjust(newX: T, newGrad: T, newVal: Double): (Double, T) = {
  // Running total, updated from inside the mapping function below.
  var penalizedValue = newVal

  val regularizedGrad = space.zipMapKeyValues.map(newX, newGrad, { case (key, coord, partial) =>
    val weight = l1reg(key)
    require(weight >= 0.0)

    if (weight == 0.0) {
      // Unregularised key: the gradient entry passes through untouched.
      partial
    } else {
      penalizedValue += Math.abs(weight * coord)

      if (coord == 0.0) {
        // At zero, take `partial - weight` if it is positive, `partial + weight`
        // if it is negative, and otherwise 0.0.
        val shiftedUp = partial + weight
        val shiftedDown = partial - weight
        if (shiftedDown > 0) shiftedDown
        else if (shiftedUp < 0) shiftedUp
        else 0.0
      } else {
        partial + math.signum(coord) * weight
      }
    }
  })

  (penalizedValue, regularizedGrad)
}
98113
@@ -105,26 +120,3 @@ class OWLQN[T](maxIter: Int, m: Int, l1reg: Double=1.0, tolerance: Double = 1E-
105120 }
106121
107122}
108-
109-
110- object OWLQN {
111- def main (args : Array [String ]) {
112- val lbfgs = new OWLQN [DenseVector [Double ]](100 ,4 )
113-
114- def optimizeThis (init : DenseVector [Double ]) = {
115- val f = new DiffFunction [DenseVector [Double ]] {
116- def calculate (x : DenseVector [Double ]) = {
117- (sum((x - 3.0 ) :^ 2.0 ),(x * 2.0 ) - 6.0 )
118- }
119- }
120-
121- val result = lbfgs.minimize(f,init)
122- }
123-
124- // optimizeThis(Counter(1->1.0,2->2.0,3->3.0))
125- // optimizeThis(Counter(3-> -2.0,2->3.0,1-> -10.0))
126- // optimizeThis(DenseVector(1.0,2.0,3.0))
127- optimizeThis(DenseVector ( - 0.0 ,0.0 , - 0.0 ))
128- }
129- }
130-
0 commit comments