Add bias correction and add_eps_in_sqrt options to rmsprop and associated transforms

PiperOrigin-RevId: 654024066
vroulet authored and OptaxDev committed Jul 25, 2024
1 parent 203ecdd commit 073af1a
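
As a quick orientation, here is a minimal usage sketch of the two options this commit adds to `optax.rmsprop`; the toy parameters and gradients are made up for illustration, and only `add_eps_in_sqrt` and `bias_correction` come from this change.

import jax.numpy as jnp
import optax

# Toy parameters and gradients, purely illustrative.
params = {'w': jnp.array([1.0, -2.0, 3.0])}
grads = {'w': jnp.array([0.1, 0.2, -0.3])}

# add_eps_in_sqrt=False uses the PyTorch-style denominator sqrt(v) + eps;
# bias_correction=True debiases the second-moment estimate (Adam-style).
opt = optax.rmsprop(
    learning_rate=1e-2, add_eps_in_sqrt=False, bias_correction=True)
state = opt.init(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)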
Showing 4 changed files with 181 additions and 40 deletions.
69 changes: 51 additions & 18 deletions optax/_src/alias.py
@@ -1337,9 +1337,11 @@ def rmsprop(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.,
add_eps_in_sqrt: bool = True,
centered: bool = False,
momentum: Optional[float] = None,
nesterov: bool = False
nesterov: bool = False,
bias_correction: bool = False,
) -> base.GradientTransformation:
# pylint: disable=line-too-long
r"""A flexible RMSProp optimizer.
@@ -1350,12 +1352,16 @@ def rmsprop(
in the literature. This alias provides an easy to configure RMSProp
optimizer that can be used to switch between several of these variants.
..warning::
PyTorch and optax's RMSprop implementations differ and could impact
performance. In the denominator, optax uses :math:`$\sqrt{v + \epsilon}$`
whereas PyTorch uses :math:`$\sqrt{v} + \epsilon$`. See
.. warning::
Default behavior of optax's RMSprop (``add_eps_in_sqrt=True``) differs from
PyTorch's implementation and could impact performance.
If ``add_eps_in_sqrt=True``, in the denominator, optax uses
:math:`\sqrt{v + \epsilon}` whereas PyTorch uses
:math:`\sqrt{v} + \epsilon`.
Using ``add_eps_in_sqrt=False`` in optax will match PyTorch's behavior.
See
https://github.com/google-deepmind/optax/issues/532 for more detail.
Examples:
>>> import optax
>>> import jax
@@ -1378,8 +1384,14 @@ def rmsprop(
Objective function: 1.36E+01
References:
Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Graves, 2013: https://arxiv.org/abs/1308.0850
Hinton, `Overview of mini-batch gradient descent
<www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_, 2012
Graves, `Generating Sequences With Recurrent Neural Networks
<https://arxiv.org/pdf/1308.0850v5>`_, 2014
Ziyin, `LaProp: Separating Momentum and Adaptivity in Adam
<https://arxiv.org/pdf/2002.04839>`_, 2021
Args:
learning_rate: A global scaling factor, either fixed or evolving along
@@ -1389,11 +1401,15 @@ def rmsprop(
initial_scale: Initial value of accumulators tracking the magnitude of
previous updates. PyTorch uses `0`, TF1 uses `1`. When reproducing results
from a paper, verify the value used by the authors.
add_eps_in_sqrt: Whether to add ``eps`` in the square root of the
denominator or outside the square root.
centered: Whether the second moment or the variance of the past gradients is
used to rescale the latest gradients.
momentum: Decay rate used by the momentum term, when it is set to `None`,
then momentum is not used at all.
nesterov: Whether Nesterov momentum is used.
bias_correction: Whether to apply bias correction to the estimates of the
second moments (and first moment if ``centered=True``).
Returns:
The corresponding `GradientTransformation`.
@@ -1402,17 +1418,33 @@ def rmsprop(
if centered:
return combine.chain(
transform.scale_by_stddev(
decay=decay, eps=eps, initial_scale=initial_scale),
decay=decay,
eps=eps,
initial_scale=initial_scale,
add_eps_in_sqrt=add_eps_in_sqrt,
bias_correction=bias_correction,
),
transform.scale_by_learning_rate(learning_rate),
(transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None else base.identity())
(
transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None
else base.identity()
),
)
return combine.chain(
transform.scale_by_rms(
decay=decay, eps=eps, initial_scale=initial_scale),
decay=decay,
eps=eps,
initial_scale=initial_scale,
add_eps_in_sqrt=add_eps_in_sqrt,
bias_correction=bias_correction,
),
transform.scale_by_learning_rate(learning_rate),
(transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None else base.identity())
(
transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None
else base.identity()
),
)
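
To make the ``add_eps_in_sqrt`` warning above concrete, a small numerical sketch of the two denominator conventions (values chosen only for illustration):

import jax.numpy as jnp

v = jnp.array([1e-6, 1e-2, 1.0])  # second-moment estimates
eps = 1e-8

denom_default = jnp.sqrt(v + eps)        # add_eps_in_sqrt=True (optax default)
denom_pytorch_style = jnp.sqrt(v) + eps  # add_eps_in_sqrt=False (PyTorch)
# The two are nearly identical for large v but diverge as v approaches zero,
# where sqrt(v + eps) is bounded below by sqrt(eps) ~ 1e-4 while
# sqrt(v) + eps approaches eps = 1e-8.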


@@ -1707,10 +1739,11 @@ def adamaxw(
to implement this as an additive loss term, however L2 regularization
does not behave as intended for adaptive gradient algorithms such as Adam.
WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or
for the bias parameters. You can use `optax.masked` to make your own AdamaxW
variant where `additive_weight_decay` is applied only to a subset of `params`.
.. warning:: Sometimes you may want to skip weight decay for BatchNorm scale
or for the bias parameters. You can use `optax.masked` to make your own
AdamaxW variant where `additive_weight_decay` is applied only to a subset of
`params`.
Examples:
>>> import optax
>>> import jax
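
The adamaxw warning above suggests building a masked variant with `optax.masked`; the following is a rough sketch under the assumption that `optax.scale_by_adamax`, `optax.add_decayed_weights`, and `optax.masked` are available, mirroring how `adamw` chains its decay term (the mask function and hyperparameters are hypothetical):

import jax
import optax

def decay_mask(params):
  # Hypothetical rule: decay only rank-2+ leaves, skipping biases and scales.
  return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)

learning_rate, weight_decay = 1e-3, 1e-4
adamaxw_masked = optax.chain(
    optax.scale_by_adamax(),
    optax.masked(optax.add_decayed_weights(weight_decay), decay_mask),
    optax.scale(-learning_rate),
)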
4 changes: 2 additions & 2 deletions optax/_src/numerics.py
@@ -29,7 +29,7 @@


def abs_sq(x: chex.Array) -> chex.Array:
"""Returns the squared norm of a (maybe complex) array.
"""Returns the squared absolute value of a (maybe complex) array.
For real `x`, JAX generates the same HLO from this, `jnp.square(x)`, `x * x`,
or `x**2`.
@@ -38,7 +38,7 @@ def abs_sq(x: chex.Array) -> chex.Array:
x: a (maybe complex) array.
Returns:
The squared norm of `x`.
The squared absolute value of `x`.
"""
if not isinstance(x, (np.ndarray, jnp.ndarray)):
raise ValueError(f"`abs_sq` accepts only NDarrays, got: {x}.")
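
A tiny check of the clarified `abs_sq` docstring, assuming the function returns the elementwise squared magnitude (the example values are made up):

import jax.numpy as jnp
from optax._src import numerics

x = jnp.array([3.0 + 4.0j])
# Elementwise |x|**2 = real**2 + imag**2, i.e. roughly [25.], with no
# reduction over elements (hence "squared absolute value", not "norm").
print(numerics.abs_sq(x))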
137 changes: 117 additions & 20 deletions optax/_src/transform.py
@@ -79,67 +79,128 @@ def update_fn(updates, state, params=None):

class ScaleByRmsState(NamedTuple):
"""State for exponential root mean-squared (RMS)-normalized updates."""
# Kept for backward compatibility, even though ScaleByRmsWithCountState
# encompasses this state.
nu: base.Updates


class ScaleByRmsWithCountState(NamedTuple):
"""State for exponential root mean-squared (RMS)-normalized updates."""
count: chex.Array # shape=(), dtype=jnp.int32.
nu: base.Updates


def scale_by_rms(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.
initial_scale: float = 0.0,
add_eps_in_sqrt: bool = True,
bias_correction: bool = False,
) -> base.GradientTransformation:
r"""Rescale updates by the root of the exp. moving avg of the square.
WARNING: PyTorch and optax's RMSprop implementations differ and could impact
performance. In the denominator, optax uses $\sqrt{v + \epsilon}$ whereas
PyTorch uses $\sqrt{v} + \epsilon$. See
.. warning::
Default behavior of optax's RMSprop (`add_eps_in_sqrt=True`) differs from
PyTorch's implementation and could impact performance.
If `add_eps_in_sqrt=True`, in the denominator, optax uses
$\sqrt{v + \epsilon}$ whereas PyTorch uses $\sqrt{v} + \epsilon$.
Using `add_eps_in_sqrt=False` in optax will match PyTorch's behavior.
See
https://github.com/google-deepmind/optax/issues/532 for more detail.
.. note::
Using `scale_by_rms(decay=b2, add_eps_in_sqrt=False, bias_correction=True)`
will match the behavior of `scale_by_adam(b1=0, b2=b2)`, while sparing the
memory cost of storing the first moment.
References:
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Hinton, `Overview of mini-batch gradient descent
<www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_, 2012
Args:
decay: Decay rate for the exponentially weighted average of squared grads.
eps: Term added to the denominator to improve numerical stability.
initial_scale: Initial value for second moment.
add_eps_in_sqrt: Whether to add ``eps`` in the square root of the
denominator or outside the square root.
bias_correction: Whether to apply bias correction to the exponentially
weighted average of squared grads.
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
nu = otu.tree_full_like(params, initial_scale) # second moment
return ScaleByRmsState(nu=nu)
if bias_correction:
return ScaleByRmsWithCountState(
count=jnp.zeros([], jnp.int32), nu=nu
)
else:
return ScaleByRmsState(nu=nu)

def update_fn(updates, state, params=None):
del params
nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, decay, 2)
updates = jtu.tree_map(
lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
return updates, ScaleByRmsState(nu=nu)
if bias_correction:
count_inc = numerics.safe_int32_increment(state.count)
nu_hat = otu.tree_bias_correction(nu, decay, count_inc)
else:
count_inc = jnp.asarray(0)
nu_hat = nu
if add_eps_in_sqrt:
scaling = jtu.tree_map(lambda n: jax.lax.rsqrt(n + eps), nu_hat)
else:
scaling = jtu.tree_map(lambda n: 1/(jnp.sqrt(n) + eps), nu_hat)
updates = jtu.tree_map(lambda s, g: s * g, scaling, updates)
if bias_correction:
new_state = ScaleByRmsWithCountState(count=count_inc, nu=nu)
else:
new_state = ScaleByRmsState(nu=nu)
return updates, new_state

return base.GradientTransformation(init_fn, update_fn)


class ScaleByRStdDevState(NamedTuple):
"""State for centered exponential moving average of squares of updates."""
# Kept for backward compatibility, even though ScaleByRStdDevWithCountState
# encompasses this state.
mu: base.Updates
nu: base.Updates


class ScaleByRStdDevWithCountState(NamedTuple):
"""State for centered exponential moving average of squares of updates."""
count: chex.Array # shape=(), dtype=jnp.int32.
mu: base.Updates
nu: base.Updates


def scale_by_stddev(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.
initial_scale: float = 0.,
add_eps_in_sqrt: bool = True,
bias_correction: bool = False,
) -> base.GradientTransformation:
"""Rescale updates by the root of the centered exp. moving average of squares.
References:
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Hinton, `Overview of mini-batch gradient descent
<www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_, 2012
Graves, `Generating Sequences With Recurrent Neural Networks
<https://arxiv.org/pdf/1308.0850v5>`_, 2014
Args:
decay: Decay rate for the exponentially weighted average of squared grads.
eps: Term added to the denominator to improve numerical stability.
initial_scale: Initial value for second moment.
add_eps_in_sqrt: Whether to add ``eps`` in the square root of the
denominator or outside the square root.
bias_correction: Whether to apply bias correction to the first and
second moment.
Returns:
A `GradientTransformation` object.
@@ -148,16 +209,48 @@ def scale_by_stddev(
def init_fn(params):
mu = otu.tree_zeros_like(params) # First moment
nu = otu.tree_full_like(params, initial_scale) # second moment
return ScaleByRStdDevState(mu=mu, nu=nu)
if bias_correction:
return ScaleByRStdDevWithCountState(
count=jnp.zeros([], jnp.int32), mu=mu, nu=nu
)
else:
return ScaleByRStdDevState(mu=mu, nu=nu)

def update_fn(updates, state, params=None):
del params
mu = otu.tree_update_moment(updates, state.mu, decay, 1)
nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, decay, 2)
if bias_correction:
count_inc = numerics.safe_int32_increment(state.count)
mu_hat = otu.tree_bias_correction(mu, decay, count_inc)
nu_hat = otu.tree_bias_correction(nu, decay, count_inc)
else:
count_inc = jnp.asarray(0)
mu_hat = mu
nu_hat = nu

if add_eps_in_sqrt:
scaling = jtu.tree_map(
lambda m, n: jax.lax.rsqrt(n - abs_sq(m) + eps),
mu_hat,
nu_hat,
)
else:
scaling = jtu.tree_map(
lambda m, n: 1/(jnp.sqrt(n - abs_sq(m)) + eps),
mu_hat,
nu_hat,
)
updates = jtu.tree_map(
lambda g, m, n: g * jax.lax.rsqrt(n - abs_sq(m) + eps),
updates, mu, nu)
return updates, ScaleByRStdDevState(mu=mu, nu=nu)
lambda s, g: s * g, scaling, updates
)
if bias_correction:
new_state = ScaleByRStdDevWithCountState(
count=count_inc, mu=mu, nu=nu
)
else:
new_state = ScaleByRStdDevState(mu=mu, nu=nu)
return updates, new_state

return base.GradientTransformation(init_fn, update_fn)
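
A per-leaf sketch of the centered scaling computed by `scale_by_stddev` above, written with plain `jnp` ops for real-valued gradients (the tree mapping and complex handling via `abs_sq` are omitted):

import jax
import jax.numpy as jnp

def centered_scale(g, mu, nu, eps=1e-8, add_eps_in_sqrt=True):
  # nu - mu**2 estimates the variance of past gradients rather than their
  # raw second moment, which is what "centered" refers to.
  var = nu - mu ** 2
  if add_eps_in_sqrt:
    return g * jax.lax.rsqrt(var + eps)
  return g / (jnp.sqrt(var) + eps)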

@@ -178,7 +271,7 @@ def scale_by_adam(
*,
nesterov: bool = False
) -> base.GradientTransformation:
"""Rescale updates according to the Adam algorithm.
r"""Rescale updates according to the Adam algorithm.
References:
Kingma et al, `Adam: A Method for Stochastic Optimization
@@ -188,10 +281,14 @@ def scale_by_adam(
<https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ>`_ 2016
.. warning::
PyTorch and optax's adam follow Algorithm 1 of the Kingma
and Ba's Adam paper, if reproducing old results note that TensorFlow
used instead the formulation just before Section 2.1 of the paper.
See https://github.com/deepmind/optax/issues/571 for more detail.
Default behavior of optax's RMSprop (``add_eps_in_sqrt=True``) differs from
PyTorch's implementation and could impact performance.
If ``add_eps_in_sqrt=True``, in the denominator, optax uses
:math:`\sqrt{v + \epsilon}` whereas PyTorch uses
:math:`\sqrt{v} + \epsilon`.
Using ``add_eps_in_sqrt=False`` in optax will match PyTorch's behavior.
See
https://github.com/google-deepmind/optax/issues/532 for more detail.
Args:
b1: Decay rate for the exponentially weighted average of grads.
11 changes: 11 additions & 0 deletions optax/_src/transform_test.py
@@ -186,6 +186,17 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10):
print(grad, value, updates)
self.assertLess(objective(init_params - updates), tol)

def test_rms_match_adam(self):
"""Test scale_by_rms add_eps_in_sqrt=False matches scale_by_adam(b1=0)."""
rms = transform.scale_by_rms(
decay=0.999, add_eps_in_sqrt=False, bias_correction=True
)
adam = transform.scale_by_adam(b1=0)
rms_state = rms.init(self.init_params)
adam_state = adam.init(self.init_params)
rms_updates, _ = rms.update(self.per_step_updates, rms_state)
adam_updates, _ = adam.update(self.per_step_updates, adam_state)
chex.assert_trees_all_close(rms_updates, adam_updates, atol=1e-5, rtol=1e-5)

if __name__ == '__main__':
absltest.main()
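
Extending the new unit test above to the alias level, one would expect the full `rmsprop` optimizer with these flags to track `adam` with `b1=0`; a sketch of that comparison (not part of the commit, toy values only):

import jax.numpy as jnp
import optax

params = {'w': jnp.array([1.0, 2.0])}
grads = {'w': jnp.array([0.5, -0.1])}

rms = optax.rmsprop(
    1e-3, decay=0.999, add_eps_in_sqrt=False, bias_correction=True)
adam = optax.adam(1e-3, b1=0.0)

rms_updates, _ = rms.update(grads, rms.init(params))
adam_updates, _ = adam.update(grads, adam.init(params))
# Expect rms_updates and adam_updates to agree to within numerical tolerance.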
