From bb6b2b404ef6a809ae20a3e1f01dd353f3ac91e6 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Thu, 28 Mar 2024 01:27:26 +0100 Subject: [PATCH 1/2] Add a mathematical description of AdamW --- optax/_src/alias.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 7f6efefa1..4d109b6ec 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -530,15 +530,49 @@ def adamw( *, nesterov: bool = False, ) -> base.GradientTransformation: - """Adam with weight decay regularization. + r"""Adam with weight decay regularization. AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term, however L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam. + Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`, + :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments + ``b1``, ``b2``, ``eps`` and ``eps_root`` respectively. The learning rate is + indexed by :math:`t` since the learning rate may also be provided by a + schedule function. Let :math:`\lambda` be the weight decay and + :math:`\theta_t` the parameter vector at time :math:`t`. + + The ``init`` function of this optimizer initializes an internal state + :math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the + first and second moments. In practice these values are stored as pytrees + containing all zeros, with the same shape as the model updates. + At step :math:`t`, the ``update`` function of this optimizer takes as + arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t` + and the parameters :math:`\theta_t` and computes updates :math:`u_t` and + new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have, + + .. math:: + + \begin{align*} + m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ + v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ + \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ + \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ + u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t + + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t-1} \right)\\ + S_t &\leftarrow (m_t, v_t). + \end{align*} + This implementation can incorporate a momentum a la Nesterov introduced by [Dozat 2016]. The resulting optimizer is then often referred as NAdamW. + With the keyword argument `nesterov=True`, the optimizer uses Nesterov + momentum, replacing the above :math:`\hat{m}_t` with + + .. math:: + \hat{m}_t \leftarrow + \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. Examples: >>> import optax From 7c3c00f4c53a25160be179fb83a8513e0b2ff971 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Thu, 28 Mar 2024 01:55:05 +0100 Subject: [PATCH 2/2] Add corrections --- optax/_src/alias.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 4d109b6ec..49c735f6d 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -535,7 +535,8 @@ def adamw( AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. + does not behave as intended for adaptive gradient algorithms such as Adam, + see [Loshchilov et al, 2019]. Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`, :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments @@ -561,7 +562,7 @@ def adamw( \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t - + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t-1} \right)\\ + + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}