
Commit 364c567

Merge pull request #2357 from huggingface/more_opt_stuff
Add caution to Adan. Add decouple decay option to LAMB.
2 parents a02b1a8 + afdf11d commit 364c567

4 files changed: +89 −11 lines

Diff for: timm/optim/_optim_factory.py

+37 lines

@@ -485,6 +485,20 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'trust_clip': True}
         ),
+        OptimInfo(
+            name='lambw',
+            opt_class=Lamb,
+            description='LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_decay': True}
+        ),
+        OptimInfo(
+            name='lambcw',
+            opt_class=Lamb,
+            description='LAMB with trust ratio clipping for stability and decoupled decay',
+            has_betas=True,
+            defaults={'trust_clip': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='lars',
             opt_class=Lars,
@@ -544,6 +558,22 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             description='Cautious Adopt',
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='cadan',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum Algorithm',
+            defaults={'caution': True, 'no_prox': False},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='cadanw',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
+            defaults={'caution': True, 'no_prox': True},
+            has_betas=True,
+            num_betas=3
+        ),
         OptimInfo(
             name='cadoptw',
             opt_class=Adopt,
@@ -557,6 +587,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='clambw',
+            opt_class=Lamb,
+            description='Cautious LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'caution': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='claprop',
             opt_class=LaProp,
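
With these registry entries in place, the new variants can be requested by name through the optimizer factory. A minimal usage sketch, assuming timm's existing create_optimizer_v2 helper; the model and hyperparameter values are placeholders:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(128, 10)  # placeholder model

    # LAMB with decoupled weight decay ('lambw'), registered above
    opt = create_optimizer_v2(model, opt='lambw', lr=5e-3, weight_decay=0.02)

    # Cautious Adan with decoupled decay ('cadanw'); Adan uses three betas
    opt = create_optimizer_v2(model, opt='cadanw', lr=1e-3, weight_decay=0.02)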

Diff for: timm/optim/adan.py

+42 −10 lines

@@ -20,7 +20,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Decoupled weight decay (L2 penalty)
         no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
         foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
     """
 
@@ -66,7 +67,8 @@ def __init__(self,
             eps: float = 1e-8,
             weight_decay: float = 0.0,
             no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
     ):
         if not 0.0 <= lr:
             raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ def __init__(self,
             eps=eps,
             weight_decay=weight_decay,
             no_prox=no_prox,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -93,6 +96,7 @@ def __setstate__(self, state):
         super(Adan, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('no_prox', False)
+            group.setdefault('caution', False)
 
     @torch.no_grad()
     def restart_opt(self):
@@ -118,6 +122,11 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
         for group in self.param_groups:
             params_with_grad = []
             grads = []
@@ -161,9 +170,19 @@ def step(self, closure=None):
             if not params_with_grad:
                 continue
 
-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                 exp_avgs=exp_avgs,
                 exp_avg_sqs=exp_avg_sqs,
                 exp_avg_diffs=exp_avg_diffs,
@@ -178,13 +197,9 @@ def step(self, closure=None):
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
                 no_prox=group['no_prox'],
+                caution=group['caution'],
             )
 
-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
         return loss
 
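The foreach=None default now defers the implementation choice to runtime: the multi-tensor path is auto-selected only when caution is disabled or when torch._foreach_maximum_ offers a Scalar overload, which the cautious mask rescaling relies on. A standalone sketch of that probe and dispatch decision, under the same assumptions as the patch (the helper names here are illustrative, not part of the diff):

    import torch

    def can_use_cautious_foreach() -> bool:
        # Probe for the Scalar overload of _foreach_maximum_; older PyTorch
        # builds only accept a tensor list as the second argument.
        try:
            return 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
        except Exception:
            return False

    def pick_adan_impl(foreach, caution) -> str:
        # Explicit foreach wins; otherwise fall back to the single-tensor
        # path when caution needs the missing Scalar overload.
        if foreach is None:
            use_foreach = (not caution) or can_use_cautious_foreach()
        else:
            use_foreach = foreach
        return 'multi_tensor' if use_foreach else 'single_tensor'

    print(pick_adan_impl(None, caution=True))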

@@ -206,6 +221,7 @@ def _single_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
         step_size_diff = lr * beta2 / bias_correction2
         step_size = lr / bias_correction1
 
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         if no_prox:
             param.mul_(1 - lr * weight_decay)
             param.addcdiv_(exp_avg, denom, value=-step_size)
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     if len(params) == 0:
         return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
     step_size_diff = lr * beta2 / bias_correction2
     step_size = lr / bias_correction1
 
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
     if no_prox:
         torch._foreach_mul_(params, 1 - lr * weight_decay)
         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
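
Both code paths apply the same masking rule from 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085): momentum entries whose sign disagrees with the current gradient are zeroed, and the survivors are rescaled by the inverse of the mask's mean, clamped so a nearly empty mask cannot blow up the step. A self-contained sketch of the single-tensor form on toy data (shapes and values are arbitrary):

    import torch

    torch.manual_seed(0)
    exp_avg = torch.randn(6)  # first-moment (momentum) buffer
    grad = torch.randn(6)     # current gradient

    # Keep only entries where momentum and gradient agree in sign.
    mask = (exp_avg * grad > 0).to(grad.dtype)

    # Rescale so the average update magnitude is preserved; the clamp guards
    # against the case where almost everything was masked out.
    mask = mask / mask.mean().clamp(min=1e-3)

    cautious_exp_avg = exp_avg * mask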

Diff for: timm/optim/lamb.py

+7 −1 lines

@@ -94,6 +94,7 @@ def __init__(
             trust_clip: bool = False,
             always_adapt: bool = False,
             caution: bool = False,
+            decoupled_decay: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -106,13 +107,15 @@
             trust_clip=trust_clip,
             always_adapt=always_adapt,
             caution=caution,
+            decoupled_decay=decoupled_decay,
         )
         super().__init__(params, defaults)
 
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('decoupled_decay', False)
 
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
@@ -199,7 +202,10 @@ def step(self, closure=None):
 
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
-                    update.add_(p, alpha=weight_decay)
+                    if group.get('decoupled_decay', False):
+                        p.add_(p, alpha=-group['lr'] * weight_decay)
+                    else:
+                        update.add_(p, alpha=weight_decay)
 
                 if weight_decay != 0 or group['always_adapt']:
                     # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
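
With decoupled_decay=True, the decay is applied directly to the parameter, scaled by the learning rate, rather than being folded into the update that feeds the layer-wise trust-ratio computation (AdamW-style decoupling). A minimal sketch contrasting the two branches on a single tensor; update here is a stand-in for the Adam-style step computed earlier in step(), and the values are placeholders:

    import torch

    lr, weight_decay = 1e-3, 0.02
    p = torch.randn(4)
    update = torch.randn(4)  # stand-in for the bias-corrected Adam-style update

    decoupled_decay = True
    if decoupled_decay:
        # Decoupled: shrink the weights in place; the trust ratio is then
        # computed from an update that excludes the decay term.
        p.add_(p, alpha=-lr * weight_decay)
    else:
        # Coupled (previous behaviour): decay joins the update and therefore
        # participates in the trust-ratio scaling.
        update.add_(p, alpha=weight_decay)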

Diff for: timm/optim/mars.py

+3 lines

@@ -54,10 +54,13 @@ def _mars_single_tensor_step(
             if c_t_norm > 1.:
                 c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+
         if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
             mask = (exp_avg * grad > 0).to(grad.dtype)
             mask.div_(mask.mean().clamp_(min=1e-3))
             exp_avg = exp_avg * mask
+
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step
