Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 526978774
  • Loading branch information
emilyfertig authored and JAXopt authors committed Apr 25, 2023
1 parent 4edd8ac commit 4aa9bc9
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 24 deletions.
8 changes: 6 additions & 2 deletions jaxopt/_src/backtracking_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class BacktrackingLineSearchState(NamedTuple):
grad: Any
num_fun_eval: int
num_grad_eval: int
failed: bool
aux: Optional[Any] = None


Expand Down Expand Up @@ -133,7 +134,8 @@ def init_state(self,
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval,
done=jnp.asarray(False),
grad=grad)
grad=grad,
failed=jnp.asarray(False))

def update(self,
stepsize: float,
Expand Down Expand Up @@ -233,6 +235,7 @@ def update(self,
stepsize,
stepsize * self.decrease_factor)
done = state.done | (error <= self.tol)
failed = state.failed | ((state.iter_num + 1 == self.maxiter) & ~done)

new_state = BacktrackingLineSearchState(iter_num=state.iter_num + 1,
value=new_value,
Expand All @@ -242,7 +245,8 @@ def update(self,
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval,
done=done,
error=error)
error=error,
failed=failed)

return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

Expand Down
30 changes: 20 additions & 10 deletions jaxopt/_src/hager_zhang_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class HagerZhangLineSearchState(NamedTuple):
error: float
params: Any
grad: Any
failed: bool
aux: Optional[Any] = None


Expand Down Expand Up @@ -239,7 +240,7 @@ def _bracket(
# Initial interval that satisfies the opposite slope condition.

def cond_fn(state):
return jnp.any(~state[0])
return jnp.any(~state[0]) | jnp.any(state[-1])

def body_fn(state):
(done,
Expand All @@ -248,7 +249,8 @@ def body_fn(state):
high,
value_middle,
grad_middle,
best_middle) = state
best_middle,
failed) = state
# Correspond to B1 in the paper.
update_right_endpoint = grad_middle >= 0.
new_high = jnp.where(~done & update_right_endpoint, middle, high)
Expand Down Expand Up @@ -285,18 +287,23 @@ def _update_interval():
new_value_middle, new_grad_middle = self._value_and_grad_on_line(
x, new_middle, descent_direction, *args, **kwargs)

# Terminate on encountering NaNs to avoid an infinite loop.
failed = (failed |
jnp.any(jnp.isnan(new_value_middle)) |
jnp.any(jnp.isnan(new_grad_middle)))
return (done,
new_low,
new_middle,
new_high,
new_value_middle,
new_grad_middle,
best_middle)
best_middle,
failed)

value_c, grad_c = self._value_and_grad_on_line(
x, c, descent_direction, *args, **kwargs)

_, final_low, _, final_high, _, _, _ = jax.lax.while_loop(
_, final_low, _, final_high, _, _, _, failed = jax.lax.while_loop(
cond_fn,
body_fn,
(jnp.array(False),
Expand All @@ -305,8 +312,9 @@ def _update_interval():
c,
value_c,
grad_c,
jnp.array(0.)))
return final_low, final_high
jnp.array(0.),
jnp.array(False)))
return final_low, final_high, failed

def init_state(self, # pylint:disable=keyword-arg-before-vararg
init_stepsize: float,
Expand Down Expand Up @@ -345,7 +353,7 @@ def init_state(self, # pylint:disable=keyword-arg-before-vararg
value + self.approximate_wolfe_threshold * jnp.abs(value))

# Create initial interval.
low, high = self._bracket(
low, high, failed = self._bracket(
params, jnp.ones_like(value),
approx_wolfe_threshold_value, descent_direction, *args, **kwargs)

Expand Down Expand Up @@ -378,7 +386,8 @@ def init_state(self, # pylint:disable=keyword-arg-before-vararg
value=value,
aux=None, # we do not need to have aux in the initial state
params=params,
grad=grad)
grad=grad,
failed=failed)

def update(self, # pylint:disable=keyword-arg-before-vararg
stepsize: float,
Expand Down Expand Up @@ -410,7 +419,6 @@ def update(self, # pylint:disable=keyword-arg-before-vararg
else:
value, grad = self._value_and_grad_fun(params, *args, **kwargs)


if descent_direction is None:
descent_direction = tree_scalar_mul(-1, grad)

Expand Down Expand Up @@ -471,6 +479,7 @@ def _reupdate():
approx_wolfe_threshold_value,
descent_direction))
done = state.done | (error <= self.tol)
failed = state.failed | ((state.iter_num + 1 == self.maxiter) & ~done)

new_state = HagerZhangLineSearchState(
iter_num=state.iter_num + 1,
Expand All @@ -481,7 +490,8 @@ def _reupdate():
low=new_low,
high=new_high,
error=error,
done=done)
done=done,
failed=failed)

return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

Expand Down
7 changes: 2 additions & 5 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def update(self,
new_aux = ls_state.aux
else:
raise ValueError("Invalid name in 'linesearch' option.")
failed_linesearch = ls_state.failed

else:
# without line search
Expand All @@ -401,6 +402,7 @@ def update(self,

new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction)
(new_value, new_aux), new_grad = self._value_and_grad_with_aux(new_params, *args, **kwargs)
failed_linesearch = jnp.asarray(False)
s = tree_sub(new_params, params)
y = tree_sub(new_grad, grad)
vdot_sy = tree_vdot(s, y)
Expand All @@ -416,11 +418,6 @@ def update(self,
else:
gamma = jnp.array(1.0)

if use_linesearch and self.linesearch == "zoom":
failed_linesearch = ls_state.failed
else: # backtracking linesearch doesn't support failed state yet
failed_linesearch = jnp.asarray(False)

new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
grad=new_grad,
Expand Down
7 changes: 2 additions & 5 deletions jaxopt/_src/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ def update(
new_aux = ls_state.aux
else:
raise ValueError("Invalid name in 'linesearch' option.")
failed_linesearch = ls_state.failed
else:
# without line search
if isinstance(self.stepsize, Callable):
Expand All @@ -558,6 +559,7 @@ def update(
(new_value, new_aux), new_grad = self._value_and_grad_with_aux(
new_params, *args, **kwargs
)
failed_linesearch = jnp.asarray(False)

s = tree_sub(new_params, params)
y = tree_sub(new_grad, state.grad)
Expand All @@ -568,11 +570,6 @@ def update(
else:
gamma_inv = jnp.ones([], dtype=curvature.dtype)

if use_linesearch and self.linesearch == "zoom":
failed_linesearch = ls_state.failed
else: # backtracking linesearch doesn't support failed state yet
failed_linesearch = jnp.asarray(False)

history_ind = state.num_updates % self.history_size
(new_s_history, new_y_history, new_theta, new_num_updates) = (
jax.lax.cond(
Expand Down
7 changes: 7 additions & 0 deletions tests/backtracking_linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _check_conditions_satisfied(
initial_grad,
final_state):
self.assertTrue(jnp.all(final_state.done))
self.assertFalse(jnp.any(final_state.failed))

descent_direction = tree_scalar_mul(-1, initial_grad)
# Check sufficient decrease for all line search methods.
Expand Down Expand Up @@ -127,6 +128,12 @@ def test_backtracking_linesearch(self, cond):
self._check_conditions_satisfied(
cond, ls.c1, ls.c2, stepsize, initial_value, initial_grad, state)

# Failed linesearch (high c1 ensures convergence condition is not met).
ls = BacktrackingLineSearch(fun=fun, maxiter=20, condition=cond, c1=2.)
_, state = ls.run(init_stepsize=1.0, params=w_init, data=data)
self.assertTrue(jnp.all(state.failed))
self.assertFalse(jnp.any(state.done))


if __name__ == '__main__':
# Uncomment the line below in order to run in float64.
Expand Down
7 changes: 7 additions & 0 deletions tests/hager_zhang_linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _check_conditions_satisfied(
initial_grad,
final_state):
self.assertTrue(jnp.all(final_state.done))
self.assertFalse(jnp.any(final_state.failed))

descent_direction = tree_scalar_mul(-1, initial_grad)
sufficient_decrease = jnp.all(
Expand Down Expand Up @@ -84,6 +85,12 @@ def test_hager_zhang_linesearch(self):
self._check_conditions_satisfied(
ls.c1, ls.c2, stepsize, initial_value, initial_grad, state)

# Failed linesearch (high c1 ensures convergence condition is not met).
ls = HagerZhangLineSearch(fun=fun, maxiter=20, c1=2.)
_, state = ls.run(init_stepsize=1., params=w_init, data=data)
self.assertTrue(jnp.all(state.failed))
self.assertFalse(jnp.any(state.done))


if __name__ == '__main__':
# Uncomment the line below in order to run in float64.
Expand Down
6 changes: 4 additions & 2 deletions tests/mirror_descent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def stepsize_schedule(t: int) -> float:
class MirrorDescentTest(test_util.JaxoptTestCase):

@parameterized.named_parameters(
('kl', None),
# FIXME: Re-enable when JAX numerical stability issue is solved.
# ('kl', None),
('kl_stable', projection_grad_kl_stable),
)
def test_multiclass_svm_dual(self, projection_grad):
Expand Down Expand Up @@ -110,7 +111,8 @@ def test_multiclass_svm_dual(self, projection_grad):
self.assertArraysAllClose(W_fit, W_skl, atol=atol)

@parameterized.named_parameters(
('kl', None),
# FIXME: Re-enable when JAX numerical stability issue is solved.
# ('kl', None),
('kl_stable', projection_grad_kl_stable),
)
def test_multiclass_svm_dual_implicit_diff(self, projection_grad):
Expand Down

0 comments on commit 4aa9bc9

Please sign in to comment.