Skip to content

Commit

Permalink
added gamma to lbfgs state
Browse files Browse the repository at this point in the history
  • Loading branch information
zaccharieramzi committed Oct 7, 2022
1 parent 982c689 commit b270cc1
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class LbfgsState(NamedTuple):
s_history: Any
y_history: Any
rho_history: jnp.ndarray
gamma: jnp.ndarray
aux: Optional[Any] = None


Expand Down Expand Up @@ -255,6 +256,7 @@ def init_state(self,
s_history=init_history(init_params, self.history_size),
y_history=init_history(init_params, self.history_size),
rho_history=jnp.zeros(self.history_size, dtype=dtype),
gamma=jnp.asarray(1.0, dtype=dtype),
aux=aux)

def update(self,
Expand All @@ -275,16 +277,10 @@ def update(self,
(value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)

start = state.iter_num % self.history_size
previous = (start + self.history_size - 1) % self.history_size

if self.use_gamma:
gamma = compute_gamma(state.s_history, state.y_history, previous)
else:
gamma = 1.0

product = inv_hessian_product(pytree=grad, s_history=state.s_history,
y_history=state.y_history,
rho_history=state.rho_history, gamma=gamma,
rho_history=state.rho_history, gamma=state.gamma,
start=start)
descent_direction = tree_scalar_mul(-1.0, product)

Expand Down Expand Up @@ -351,13 +347,19 @@ def update(self,
y_history = update_history(state.y_history, y, last)
rho_history = update_history(state.rho_history, rho, last)

if self.use_gamma:
gamma = compute_gamma(s_history, y_history, last)
else:
gamma = jnp.array(1.0)

new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(new_grad),
s_history=s_history,
y_history=y_history,
rho_history=rho_history,
gamma=gamma,
# FIXME: we should return new_aux here but
# BacktrackingLineSearch currently doesn't support
# an aux.
Expand Down

0 comments on commit b270cc1

Please sign in to comment.