# limitations under the License.

import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple

import torch
from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Decoupled weight decay (L2 penalty)
        no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
        foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
    """

@@ -66,7 +67,8 @@ def __init__(self,
            eps: float = 1e-8,
            weight_decay: float = 0.0,
            no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
    ):
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ def __init__(self,
            eps=eps,
            weight_decay=weight_decay,
            no_prox=no_prox,
+            caution=caution,
            foreach=foreach,
        )
        super().__init__(params, defaults)
@@ -93,6 +96,7 @@ def __setstate__(self, state):
        super(Adan, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)
+            group.setdefault('caution', False)

    @torch.no_grad()
    def restart_opt(self):
@@ -118,6 +122,11 @@ def step(self, closure=None):
            with torch.enable_grad():
                loss = closure()

+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
        for group in self.param_groups:
            params_with_grad = []
            grads = []
@@ -161,9 +170,19 @@ def step(self, closure=None):
            if not params_with_grad:
                continue

-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
@@ -178,13 +197,9 @@ def step(self, closure=None):
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
+                caution=group['caution'],
            )

-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
        return loss


@@ -206,6 +221,7 @@ def _single_tensor_adan(
        weight_decay: float,
        eps: float,
        no_prox: bool,
+        caution: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
        step_size_diff = lr * beta2 / bias_correction2
        step_size = lr / bias_correction1

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
        if no_prox:
            param.mul_(1 - lr * weight_decay)
            param.addcdiv_(exp_avg, denom, value=-step_size)
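
For reference, the masking step added to _single_tensor_adan above can be checked in isolation. The snippet below is a minimal standalone sketch, not part of the diff, and the tensor values are made up purely to show how components whose momentum and gradient disagree in sign are zeroed out while the rest are rescaled.

import torch

# Toy momentum and gradient; only the first component has matching signs.
exp_avg = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([0.4, 0.3, -0.2])

mask = (exp_avg * grad > 0).to(grad.dtype)   # -> [1., 0., 0.]
mask.div_(mask.mean().clamp_(min=1e-3))      # rescale by 1/mean -> [3., 0., 0.]
masked_exp_avg = exp_avg * mask              # -> [1.5, 0., 0.]
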
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
        weight_decay: float,
        eps: float,
        no_prox: bool,
+        caution: bool,
):
    if len(params) == 0:
        return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
    step_size_diff = lr * beta2 / bias_correction2
    step_size = lr / bias_correction1

+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
    if no_prox:
        torch._foreach_mul_(params, 1 - lr * weight_decay)
        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
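
A quick usage sketch with the new flag. The import path below is hypothetical, adjust it to wherever this Adan class lives in your tree; everything else follows the signature in the diff.

import torch
import torch.nn.functional as F

from adan import Adan  # hypothetical import path for the Adan class shown above

model = torch.nn.Linear(10, 2)
optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=0.02, caution=True)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # with foreach=None, the multi-tensor path is used only when
optimizer.zero_grad()  # torch._foreach_maximum_ exposes a Scalar overload

With caution=True and foreach left at None, the optimizer falls back to the single-tensor path on PyTorch builds where _foreach_maximum_ has no Scalar overload, which is exactly what the has_scalar_maximum probe added to step() checks.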