Add a mathematical description of AdamW #894

Merged · 2 commits · Mar 28, 2024
39 changes: 37 additions & 2 deletions optax/_src/alias.py
@@ -530,15 +530,50 @@ def adamw(
*,
nesterov: bool = False,
) -> base.GradientTransformation:
"""Adam with weight decay regularization.
r"""Adam with weight decay regularization.

AdamW uses weight decay to regularize learning towards small weights, as
this leads to better generalization. In SGD you can also implement this with
L2 regularization as an additive loss term; however, L2 regularization
does not behave as intended for adaptive gradient algorithms such as Adam,
see [Loshchilov et al, 2019].
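
As an illustration, the decoupled decay is requested directly through the
``weight_decay`` argument rather than by adding an L2 penalty to the loss
(a minimal usage sketch; the hyper-parameter values are only illustrative):

.. code-block:: python

    import optax

    # Decoupled weight decay (AdamW): the decay enters the update rule
    # itself instead of being added to the loss as an L2 penalty.
    optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)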

Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`,
:math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments
``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is
indexed by :math:`t` since the learning rate may also be provided by a
schedule function. Let :math:`\lambda` be the weight decay and
:math:`\theta_t` the parameter vector at time :math:`t`.

The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the
first and second moments. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.
At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t`, the previous optimizer state
:math:`S_{t-1}` and the parameters :math:`\theta_t`, and computes updates
:math:`u_t` and the new state :math:`S_t`. Thus, for :math:`t > 0`, we have,

.. math::

\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t
+ \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}
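
For intuition, a minimal sketch of these equations in plain ``jax.numpy``
could look as follows (the helper name ``adamw_step`` and the scalar
hyper-parameters are illustrative, not the actual optax implementation):

.. code-block:: python

    import jax.numpy as jnp

    def adamw_step(g, m, v, params, t, lr, b1=0.9, b2=0.999,
                   eps=1e-8, eps_root=0.0, weight_decay=1e-4):
      # First and second moment estimates.
      m = b1 * m + (1 - b1) * g
      v = b2 * v + (1 - b2) * g**2
      # Bias correction.
      m_hat = m / (1 - b1**t)
      v_hat = v / (1 - b2**t)
      # Scaled update with decoupled weight decay.
      u = -lr * (m_hat / (jnp.sqrt(v_hat + eps_root) + eps)
                 + weight_decay * params)
      return u, (m, v)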

This implementation can incorporate Nesterov momentum, as introduced by
[Dozat 2016]; the resulting optimizer is then often referred to as NAdamW.
With the keyword argument ``nesterov=True``, the optimizer uses Nesterov
momentum, replacing the above :math:`\hat{m}_t` with

.. math::
\hat{m}_t \leftarrow
\beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.
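
In code, this Nesterov-style correction would amount to replacing the
``m_hat`` line in the sketch above with something like the following
(again a hypothetical sketch, not the optax source):

.. code-block:: python

    # Nesterov-style bias-corrected first moment (used when ``nesterov=True``).
    m_hat = b1 * m / (1 - b1**(t + 1)) + (1 - b1) * g / (1 - b1**t)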

Examples:
>>> import optax