From b270cc1a44144d84faca5ae73b2d75d1a89c5f65 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Fri, 7 Oct 2022 20:10:20 +0200 Subject: [PATCH] added gamma to lbfgs state --- jaxopt/_src/lbfgs.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 639a839a..9f79ada9 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -142,6 +142,7 @@ class LbfgsState(NamedTuple): s_history: Any y_history: Any rho_history: jnp.ndarray + gamma: jnp.ndarray aux: Optional[Any] = None @@ -255,6 +256,7 @@ def init_state(self, s_history=init_history(init_params, self.history_size), y_history=init_history(init_params, self.history_size), rho_history=jnp.zeros(self.history_size, dtype=dtype), + gamma=jnp.asarray(1.0, dtype=dtype), aux=aux) def update(self, @@ -275,16 +277,10 @@ def update(self, (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs) start = state.iter_num % self.history_size - previous = (start + self.history_size - 1) % self.history_size - - if self.use_gamma: - gamma = compute_gamma(state.s_history, state.y_history, previous) - else: - gamma = 1.0 product = inv_hessian_product(pytree=grad, s_history=state.s_history, y_history=state.y_history, - rho_history=state.rho_history, gamma=gamma, + rho_history=state.rho_history, gamma=state.gamma, start=start) descent_direction = tree_scalar_mul(-1.0, product) @@ -351,6 +347,11 @@ def update(self, y_history = update_history(state.y_history, y, last) rho_history = update_history(state.rho_history, rho, last) + if self.use_gamma: + gamma = compute_gamma(s_history, y_history, last) + else: + gamma = jnp.array(1.0) + new_state = LbfgsState(iter_num=state.iter_num + 1, value=new_value, stepsize=jnp.asarray(new_stepsize), @@ -358,6 +359,7 @@ def update(self, s_history=s_history, y_history=y_history, rho_history=rho_history, + gamma=gamma, # FIXME: we should return new_aux here but # BacktrackingLineSearch currently doesn't support # an aux.