Commit 0c8b25b
Internal change
PiperOrigin-RevId: 507773415
srvasude authored and JAXopt authors committed Feb 7, 2023
1 parent 60259a4 commit 0c8b25b
Showing 2 changed files with 33 additions and 10 deletions.
36 changes: 30 additions & 6 deletions jaxopt/_src/hager_zhang_linesearch.py
@@ -49,6 +49,7 @@ class HagerZhangLineSearchState(NamedTuple):
   error: float
   params: Any
   grad: Any
+  aux: Optional[Any] = None


 @dataclass(eq=False)
@@ -62,6 +63,9 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
     value_and_grad: if ``False``, ``fun`` should return the function value only.
       If ``True``, ``fun`` should return both the function value and the
       gradient.
+    has_aux: if ``False``, ``fun`` should return the function value only.
+      If ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux``
+      is a pytree of auxiliary values.
     c1: constant used by the Wolfe and Approximate Wolfe condition.
     c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
@@ -79,6 +83,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
   """
   fun: Callable  # pylint:disable=g-bare-generic
   value_and_grad: bool = False
+  has_aux: bool = False

   maxiter: int = 30
   tol: float = 0.
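
For orientation, here is a minimal sketch of the new flag in use, assuming the public jaxopt.HagerZhangLineSearch wrapper and the run method it inherits from base.IterativeLineSearch (the objective and data below are illustrative):

import jax.numpy as jnp
from jaxopt import HagerZhangLineSearch

def fun(w):
  # With has_aux=True, fun must return a (value, aux) pair.
  loss = jnp.sum((w - 1.0) ** 2)
  return loss, {"loss": loss}

ls = HagerZhangLineSearch(fun=fun, has_aux=True)
stepsize, ls_state = ls.run(init_stepsize=1.0, params=jnp.zeros(3))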
@@ -95,8 +100,11 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
   unroll: base.AutoOrBoolean = "auto"

   def _value_and_grad_on_line(self, x, c, descent_direction, *args, **kwargs):
-    value, grad = self._value_and_grad_fun(
-        tree_add_scalar_mul(x, c, descent_direction), *args, **kwargs)
+    z = tree_add_scalar_mul(x, c, descent_direction)
+    if self.has_aux:
+      (value, _), grad = self._value_and_grad_fun(z, *args, **kwargs)
+    else:
+      value, grad = self._value_and_grad_fun(z, *args, **kwargs)
     return value, tree_vdot(grad, descent_direction)

   def _satisfies_wolfe_and_approx_wolfe(
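
_value_and_grad_on_line restricts fun to the search ray: for phi(c) = fun(x + c * d) it returns phi(c) and the slope phi'(c) = <grad fun(x + c * d), d>. Here tree_add_scalar_mul and tree_vdot are the pytree analogues of x + c * d and jnp.vdot, and the aux output, if any, is simply discarded. The same computation on a plain array, as a standalone sketch (the function and values are illustrative):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(x ** 2)

x = jnp.array([1.0, -2.0])
d = -jax.grad(f)(x)          # descent direction
c = 0.1

value, grad = jax.value_and_grad(f)(x + c * d)
slope = jnp.vdot(grad, d)    # derivative of c -> f(x + c * d) at this c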
@@ -324,7 +332,11 @@ def init_state(self, # pylint:disable=keyword-arg-before-vararg
     del init_stepsize

     if value is None or grad is None:
-      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+      if self.has_aux:
+        (value, _), grad = self._value_and_grad_fun(params, *args, **kwargs)
+      else:
+        value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+

     if descent_direction is None:
       descent_direction = tree_scalar_mul(-1, grad)
@@ -364,6 +376,7 @@ def init_state(self, # pylint:disable=keyword-arg-before-vararg
         error=error,
         done=done,
         value=value,
+        aux=None,  # aux is not needed in the initial state
         params=params,
         grad=grad)

@@ -392,7 +405,11 @@ def update(self, # pylint:disable=keyword-arg-before-vararg
     """

     if value is None or grad is None:
-      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+      if self.has_aux:
+        (value, _), grad = self._value_and_grad_fun(params, *args, **kwargs)
+      else:
+        value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+

     if descent_direction is None:
       descent_direction = tree_scalar_mul(-1, grad)
@@ -436,7 +453,13 @@ def _reupdate():

     new_stepsize = jnp.where(state.done, stepsize, best_point)
     new_params = tree_add_scalar_mul(params, best_point, descent_direction)
-    new_value, new_grad = self._value_and_grad_fun(new_params, *args, **kwargs)
+    if self.has_aux:
+      (new_value, new_aux), new_grad = self._value_and_grad_fun(
+          new_params, *args, **kwargs)
+    else:
+      new_value, new_grad = self._value_and_grad_fun(
+          new_params, *args, **kwargs)
+      new_aux = None

     error = jnp.where(state.done, state.error,
                       self._satisfies_wolfe_and_approx_wolfe(
@@ -453,6 +476,7 @@ def _reupdate():
         iter_num=state.iter_num + 1,
         value=new_value,
         grad=new_grad,
+        aux=new_aux,
         params=new_params,
         low=new_low,
         high=new_high,
@@ -465,4 +489,4 @@ def __post_init__(self):
     if self.value_and_grad:
       self._value_and_grad_fun = self.fun
     else:
-      self._value_and_grad_fun = jax.value_and_grad(self.fun)
+      self._value_and_grad_fun = jax.value_and_grad(self.fun, has_aux=self.has_aux)
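
This relies on the standard has_aux contract of jax.value_and_grad: with has_aux=True the wrapped function returns ((value, aux), grad), which is why the branches above unpack (value, _) when the flag is set. A small illustration (the objective is made up):

import jax
import jax.numpy as jnp

def fun(w):
  residual = w - 3.0
  return jnp.sum(residual ** 2), {"residual": residual}

(value, aux), grad = jax.value_and_grad(fun, has_aux=True)(jnp.ones(2))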
7 changes: 3 additions & 4 deletions jaxopt/_src/lbfgs.py
@@ -385,11 +385,10 @@ def update(self,
           params, value, grad,
           descent_direction,
           *args, **kwargs)
-
       new_params = ls_state.params
-      (new_value, new_aux), new_grad = self._value_and_grad_with_aux(
-          new_params, *args, **kwargs)
-
+      new_value = ls_state.value
+      new_grad = ls_state.grad
+      new_aux = ls_state.aux
     else:
       raise ValueError("Invalid name in 'linesearch' option.")
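
Since HagerZhangLineSearchState now carries value, grad, and aux, LBFGS can read them straight off the line-search state instead of re-evaluating the objective at the new parameters, saving one function-and-gradient evaluation per iteration. End to end, this is what enables auxiliary outputs with the Hager-Zhang line search; a sketch assuming the public jaxopt.LBFGS wrapper (the least-squares problem and names are illustrative):

import jax.numpy as jnp
import jaxopt

def fun(w, X, y):
  err = X @ w - y
  return jnp.mean(err ** 2), {"per_example_error": err}  # (value, aux)

X = jnp.ones((4, 2))
y = jnp.zeros(4)
solver = jaxopt.LBFGS(fun=fun, has_aux=True, linesearch="hager-zhang")
params, state = solver.run(jnp.zeros(2), X, y)  # state.aux holds the last aux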

