Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatting in momo docstring + doctest #950

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions optax/contrib/_momo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,37 @@ def momo(
final loss.

MoMo performs SGD with momentum with a Polyak-type learning rate. The
effective step size is
``min(learning_rate, <adaptive term>)``
effective step size is ``min(learning_rate, <adaptive term>)``, where the
adaptive term is computed on the fly.

where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update
function using the keyword argument ``value``.

Examples:
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... value, grad = jax.value_and_grad(f)(params)
... params, opt_state = solver.update(grad, opt_state, params, value=value)
... print('Objective function: ', f(params))
Objective function: 3.5
Objective function: 0.0
Objective function: 0.0
Objective function: 0.0
Objective function: 0.0

Note that in ``update_fn`` you need to pass the latest (batch) loss value to
the argument `value`.

References:
Schaipp et al., `MoMo: Momentum Models for Adaptive Learning Rates
<https://arxiv.org/abs/2305.07583>`_, 2023

Args:
learning_rate: User-specified learning rate. Recommended to be chosen rather
large, by default 1.0.
Expand Down Expand Up @@ -187,17 +207,36 @@ def momo_adam(
final loss.

MoMo performs Adam(W) with a Polyak-type learning rate. The
effective step size is
``min(learning_rate, <adaptive term>)``
effective step size is ``min(learning_rate, <adaptive term>)``, where the
adaptive term is computed on the fly.

where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update
function using the keyword argument ``value``.

Note that in ``update_fn`` you need to pass the latest (batch) loss value to
the argument `value`.
Examples:
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... value, grad = jax.value_and_grad(f)(params)
... params, opt_state = solver.update(grad, opt_state, params, value=value)
... print('Objective function: ', f(params))
Objective function: 0.00029999594
Objective function: 0.0
Objective function: 0.0
Objective function: 0.0
Objective function: 0.0

References:
Schaipp et al., `MoMo: Momentum Models for Adaptive Learning Rates
<https://arxiv.org/abs/2305.07583>`_, 2023

Args:
learning_rate: User-specified learning rate. Recommended to be chosen rather
large, by default 1.0.
Expand Down
Loading