Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added gamma to LBFGSState #320

Merged
merged 1 commit into from
Oct 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class LbfgsState(NamedTuple):
s_history: Any
y_history: Any
rho_history: jnp.ndarray
gamma: jnp.ndarray
aux: Optional[Any] = None


Expand Down Expand Up @@ -255,6 +256,7 @@ def init_state(self,
s_history=init_history(init_params, self.history_size),
y_history=init_history(init_params, self.history_size),
rho_history=jnp.zeros(self.history_size, dtype=dtype),
gamma=jnp.asarray(1.0, dtype=dtype),
aux=aux)

def update(self,
Expand All @@ -275,16 +277,10 @@ def update(self,
(value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)

start = state.iter_num % self.history_size
previous = (start + self.history_size - 1) % self.history_size

if self.use_gamma:
gamma = compute_gamma(state.s_history, state.y_history, previous)
else:
gamma = 1.0

product = inv_hessian_product(pytree=grad, s_history=state.s_history,
y_history=state.y_history,
rho_history=state.rho_history, gamma=gamma,
rho_history=state.rho_history, gamma=state.gamma,
start=start)
descent_direction = tree_scalar_mul(-1.0, product)

Expand Down Expand Up @@ -351,13 +347,19 @@ def update(self,
y_history = update_history(state.y_history, y, last)
rho_history = update_history(state.rho_history, rho, last)

if self.use_gamma:
gamma = compute_gamma(s_history, y_history, last)
else:
gamma = jnp.array(1.0)

new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(new_grad),
s_history=s_history,
y_history=y_history,
rho_history=rho_history,
gamma=gamma,
# FIXME: we should return new_aux here but
# BacktrackingLineSearch currently doesn't support
# an aux.
Expand Down