feat(optim): add support of AdEMAMix optimizer (#373)
* docs(optim): fix docstrings of optimizers

* feat(optim): add AdEMAMix

* feat(references): add new optimizer to training options

* docs(readme): update README papers

* docs(docs): add new optimizer

* test(optim): add test for new optimizer

* style(ruff): fix lint

* ci(github): fix docker syntax in CI

* test(models): fix repvgg test

* fix(optim): fix AdEMAMix
frgfm authored Sep 9, 2024
1 parent c47411b commit 2973184
Showing 15 changed files with 198 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/builds.yml
@@ -97,7 +97,7 @@ jobs:
poetry lock
poetry export -f requirements.txt --without-hashes --output requirements.txt
- name: Build & run docker
- run: cd api && docker-compose up -d --build
+ run: cd api && docker compose up -d --build
- name: Docker sanity check
run: sleep 15 && nc -vz localhost 8050
- name: Ping server
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -74,7 +74,7 @@ jobs:
poetry lock
poetry export -f requirements.txt --without-hashes --with dev --output requirements.txt
- name: Build & run docker
- run: cd api && docker-compose up -d --build
+ run: cd api && docker compose up -d --build
- name: Docker sanity check
run: sleep 15 && nc -vz localhost 8050
- name: Ping server
2 changes: 1 addition & 1 deletion README.md
@@ -132,7 +132,7 @@ pip install -e Holocron/.
- boxes: [Distance-IoU & Complete-IoU losses](https://arxiv.org/abs/1911.08287)

### Trying something else than Adam
- - Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [TAdam](https://arxiv.org/abs/2003.00179), [AdamP](https://arxiv.org/abs/2006.08217), [AdaBelief](https://arxiv.org/abs/2010.07468), [Adan](https://arxiv.org/abs/2208.06677), and customized versions (RaLars)
+ - Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [TAdam](https://arxiv.org/abs/2003.00179), [AdamP](https://arxiv.org/abs/2006.08217), [AdaBelief](https://arxiv.org/abs/2010.07468), [Adan](https://arxiv.org/abs/2208.06677), and customized versions (RaLars), [AdEMAMix](https://arxiv.org/abs/2409.03137)
- Optimizer wrapper: [Lookahead](https://arxiv.org/abs/1907.08610), Scout (experimental)
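Below is a minimal sketch of the drop-in substitution this list describes, using only the `holocron.optim` import path added in this commit (the toy model is illustrative):

```python
from torch import nn

from holocron.optim import AdEMAMix

model = nn.Linear(32, 10)
# Same call pattern as torch.optim.Adam, with a third beta for the slow EMA
# and the mixing coefficient alpha on top
optimizer = AdEMAMix(model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0)
```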


2 changes: 1 addition & 1 deletion api/tests/routes/test_classification.py
@@ -1,7 +1,7 @@
import pytest


- @pytest.mark.asyncio()
+ @pytest.mark.asyncio
async def test_classification(test_app_asyncio, mock_classification_image):
response = await test_app_asyncio.post("/classification", files={"file": mock_classification_image})
assert response.status_code == 200
2 changes: 2 additions & 0 deletions docs/source/optim.rst
@@ -27,6 +27,8 @@ Implementations of recent parameter optimizer for Pytorch modules.

.. autoclass:: Adan

+ .. autoclass:: AdEMAMix


Optimizer wrappers
------------------
1 change: 1 addition & 0 deletions holocron/optim/__init__.py
@@ -2,6 +2,7 @@
from .adabelief import AdaBelief
from .adamp import AdamP
from .adan import Adan
+ from .ademamix import AdEMAMix
from .lamb import LAMB
from .lars import LARS
from .ralars import RaLars
2 changes: 1 addition & 1 deletion holocron/optim/adabelief.py
@@ -24,7 +24,7 @@ class AdaBelief(Adam):
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon
where :math:`g_t` is the gradient of :math:`\theta_t`,
- :math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
+ :math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ s_0 = 0`, :math:`\epsilon > 0`.
Then we correct their biases using:
2 changes: 1 addition & 1 deletion holocron/optim/adamp.py
@@ -25,7 +25,7 @@ class AdamP(Adam):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
where :math:`g_t` is the gradient of :math:`\theta_t`,
- :math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
+ :math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = g_0,\ v_0 = 0`.
Then we correct their biases using:
176 changes: 176 additions & 0 deletions holocron/optim/ademamix.py
@@ -0,0 +1,176 @@
# Copyright (C) 2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
from typing import Callable, Iterable, List, Optional, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

__all__ = ["AdEMAMix", "ademamix"]


class AdEMAMix(Optimizer):
r"""Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/pdf/2409.03137>`_.
The estimation of momentums is described as follows, :math:`\forall t \geq 1`:
.. math::
m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\
m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2
where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`.
Then we correct their biases using:
.. math::
\hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\
\hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}
And finally the update step is performed using the following rule:
.. math::
\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon}
where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value),
:math:`\eta` is the learning rate, :math:`\alpha > 0` and :math:`\epsilon > 0`.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate
betas (Tuple[float, float, float], optional): coefficients used for running averages (default: (0.9, 0.999, 0.9999))
alpha (float, optional): coefficient weighting the slow EMA in the update rule (default: 5.0)
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""

def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
eps: float = 1e-8,
weight_decay: float = 0.0,
) -> None:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if eps < 0.0:
raise ValueError(f"Invalid epsilon value: {eps}")
for idx, beta in enumerate(betas):
if not 0.0 <= beta < 1.0:
raise ValueError(f"Invalid beta parameter at index {idx}: {beta}")
defaults = {"lr": lr, "betas": betas, "alpha": alpha, "eps": eps, "weight_decay": weight_decay}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avgs_slow = []
exp_avg_sqs = []
state_steps = []

for p in group["params"]:
if p.grad is not None:
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
grads.append(p.grad)

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_slow"] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

exp_avgs.append(state["exp_avg"])
exp_avgs_slow.append(state["exp_avg_slow"])
exp_avg_sqs.append(state["exp_avg_sq"])

# update the steps for each param group update
state["step"] += 1
# record the step after step update
state_steps.append(state["step"])

beta1, beta2, beta3 = group["betas"]
ademamix(
params_with_grad,
grads,
exp_avgs,
exp_avgs_slow,
exp_avg_sqs,
state_steps,
beta1,
beta2,
beta3,
group["alpha"],
group["lr"],
group["weight_decay"],
group["eps"],
)
return loss


def ademamix(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avgs_slow: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[int],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
lr: float,
weight_decay: float,
eps: float,
) -> None:
r"""Functional API that performs AdaBelief algorithm computation.
See :class:`~holocron.optim.AdaBelief` for details.
"""
for i, param in enumerate(params):
grad = grads[i]
m1 = exp_avgs[i]
m2 = exp_avgs_slow[i]
nu = exp_avg_sqs[i]
step = state_steps[i]

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step

if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)

# Decay the first and second moment running average coefficient
m1.mul_(beta1).add_(grad, alpha=1 - beta1)
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m2.mul_(beta3).add_(grad, alpha=1 - beta3)

denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)

param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)
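For context, a hedged end-to-end sketch of a single optimization step with the class above (the toy model, data, and loss are illustrative assumptions, not part of this commit):

```python
import torch
from torch import nn

from holocron.optim import AdEMAMix

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
optimizer = AdEMAMix(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999, 0.9999),  # (beta1, beta2, beta3): fast EMA, second moment, slow EMA
    alpha=5.0,                   # weight of the slow EMA term in the update
    weight_decay=1e-2,
)

x, target = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = nn.functional.cross_entropy(model(x), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # dispatches to the functional ademamix() above
```

After the call, the per-parameter buffers initialized in `step` (`exp_avg`, `exp_avg_slow`, `exp_avg_sq`, `step`) can be inspected through `optimizer.state`.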
2 changes: 1 addition & 1 deletion holocron/optim/lamb.py
@@ -22,7 +22,7 @@ class LAMB(Optimizer):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
where :math:`g_t` is the gradient of :math:`\theta_t`,
- :math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
+ :math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ v_0 = 0`.
Then we correct their biases using:
2 changes: 1 addition & 1 deletion holocron/optim/tadam.py
@@ -26,7 +26,7 @@ class TAdam(Optimizer):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1})
where :math:`g_t` is the gradient of :math:`\theta_t`,
- :math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
+ :math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ v_0 = 0,\ W_0 = \frac{\beta_1}{1 - \beta_1}`;
:math:`\nu` is the degrees of freedom and :math:`d` is the number of dimensions of the parameter gradient.
6 changes: 5 additions & 1 deletion references/classification/train.py
@@ -30,7 +30,7 @@
from holocron.models import classification
from holocron.models.presets import CIFAR10 as CIF10
from holocron.models.presets import IMAGENETTE
- from holocron.optim import AdaBelief, AdamP
+ from holocron.optim import AdaBelief, AdamP, AdEMAMix
from holocron.trainer import ClassificationTrainer
from holocron.utils.data import Mixup
from holocron.utils.misc import find_image_size
@@ -208,6 +208,10 @@ def main(args):
optimizer = AdamP(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
elif args.opt == "adabelief":
optimizer = AdaBelief(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
elif args.opt == "ademamix":
optimizer = AdEMAMix(
model_params, args.lr, betas=(0.95, 0.99, 0.9999), eps=1e-6, weight_decay=args.weight_decay
)

log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
trainer = ClassificationTrainer(
2 changes: 1 addition & 1 deletion tests/test_models_classification.py
@@ -42,7 +42,7 @@ def test_repvgg_reparametrize():
assert mod.weight.data.shape[2:] == (3, 3)
# Check that values are still matching
with torch.no_grad():
- assert torch.allclose(out, model(x), atol=1e-4)
+ assert torch.allclose(out, model(x), atol=1e-3)


def test_mobileone_reparametrize():
2 changes: 1 addition & 1 deletion tests/test_ops.py
@@ -6,7 +6,7 @@
from holocron import ops


- @pytest.fixture()
+ @pytest.fixture
def boxes():
return torch.tensor(
[[0, 0, 100, 100], [50, 50, 100, 100], [50, 50, 150, 150], [100, 100, 200, 200]], dtype=torch.float32
4 changes: 4 additions & 0 deletions tests/test_optim.py
@@ -65,3 +65,7 @@ def test_adamp():

def test_adan():
_test_optimizer("Adan")


+ def test_ademamix():
+ _test_optimizer("AdEMAMix")
