Commit 7371f85

[Fea] Support amsgrad in Adam/AdamW (#1033)
* support amsgrad in Adam/AdamW
* Apply suggestions from code review

Co-authored-by: megemini <megemini@outlook.com>
1 parent 213bd31 commit 7371f85
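
Note (added for context, not part of the commit message): the new flag selects the AMSGrad variant of Adam/AdamW from Reddi et al., "On the Convergence of Adam and Beyond". AMSGrad keeps a running maximum of the second-moment estimate and uses that maximum in the denominator, so the effective per-parameter step size cannot grow over time. With the usual moment estimates, and bias-correction terms omitted for brevity, one common presentation of the update is:

$$
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2, \qquad
\hat{v}_t = \max(\hat{v}_{t-1}, v_t), \qquad
\theta_{t+1} = \theta_t - \frac{\eta\, m_t}{\sqrt{\hat{v}_t} + \epsilon}
$$

Plain Adam uses $v_t$ itself (bias-corrected) in the denominator instead of the running maximum $\hat{v}_t$.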

File tree

1 file changed (+20, -0 lines changed)

ppsci/optimizer/optimizer.py

Lines changed: 20 additions & 0 deletions
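
The noteworthy part of the diff below is the backward-compatibility guard: `amsgrad` is only forwarded when the installed Paddle's `optim.Adam`/`optim.AdamW` actually declares that parameter, so older Paddle releases keep working and silently fall back to the non-AMSGrad optimizer. A minimal standalone sketch of the same pattern, using a hypothetical `make_adam` helper that is not part of the commit:

```python
import inspect

import paddle.optimizer as optim


def make_adam(parameters, learning_rate=1e-3, amsgrad=False):
    # Hypothetical helper mirroring the guard used in the diff below:
    # forward `amsgrad` only if this Paddle build's Adam accepts it.
    extra_kwargs = {}
    if "amsgrad" in inspect.signature(optim.Adam.__init__).parameters:
        extra_kwargs["amsgrad"] = amsgrad
    # On older Paddle releases the flag is simply dropped.
    return optim.Adam(
        learning_rate=learning_rate,
        parameters=parameters,
        **extra_kwargs,
    )
```

Checking the signature rather than a Paddle version string keeps the guard robust to backports and avoids hard-coding a minimum version.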
@@ -188,6 +188,8 @@ class Adam:
         weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): Regularization strategy. Defaults to None.
         grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient clipping strategy. Defaults to None.
         lazy_mode (bool, optional): Whether to enable lazy mode for moving-average. Defaults to False.
+        amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper
+            `On the Convergence of Adam and Beyond <https://openreview.net/forum?id=ryQu7f-RZ>`_. Defaults to False.

     Examples:
         >>> import ppsci
@@ -208,6 +210,7 @@ def __init__(
             Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]
         ] = None,
         lazy_mode: bool = False,
+        amsgrad: bool = False,
     ):
         self.learning_rate = learning_rate
         self.beta1 = beta1
@@ -217,6 +220,7 @@ def __init__(
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.lazy_mode = lazy_mode
+        self.amsgrad = amsgrad

     def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         # model_list is None in static graph
@@ -225,6 +229,11 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         parameters = (
             sum([m.parameters() for m in model_list], []) if model_list else None
         )
+        import inspect
+
+        extra_kwargs = {}
+        if "amsgrad" in inspect.signature(optim.Adam.__init__).parameters:
+            extra_kwargs["amsgrad"] = self.amsgrad
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -234,6 +243,7 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             grad_clip=self.grad_clip,
             lazy_mode=self.lazy_mode,
             parameters=parameters,
+            **extra_kwargs,
         )
         return opt

@@ -386,6 +396,8 @@ class AdamW:
         grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient clipping strategy. Defaults to None.
         no_weight_decay_name (Optional[str]): List of names of no weight decay parameters split by white space. Defaults to None.
         one_dim_param_no_weight_decay (bool, optional): Apply no weight decay on 1-D parameter(s). Defaults to False.
+        amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper
+            `On the Convergence of Adam and Beyond <https://openreview.net/forum?id=ryQu7f-RZ>`_. Defaults to False.

     Examples:
         >>> import ppsci
@@ -405,6 +417,7 @@ def __init__(
         ] = None,
         no_weight_decay_name: Optional[str] = None,
         one_dim_param_no_weight_decay: bool = False,
+        amsgrad: bool = False,
     ):
         super().__init__()
         self.learning_rate = learning_rate
@@ -417,6 +430,7 @@ def __init__(
             no_weight_decay_name.split() if no_weight_decay_name else []
         )
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
+        self.amsgrad = amsgrad

     def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         # model_list is None in static graph
@@ -458,6 +472,11 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             if model_list
             else []
         )
+        import inspect
+
+        extra_kwargs = {}
+        if "amsgrad" in inspect.signature(optim.AdamW.__init__).parameters:
+            extra_kwargs["amsgrad"] = self.amsgrad

         opt = optim.AdamW(
             learning_rate=self.learning_rate,
@@ -468,6 +487,7 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             apply_decay_param_fun=self._apply_decay_param_fun,
+            **extra_kwargs,
         )
         return opt

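For completeness, a hedged usage sketch of the new keyword through the PaddleScience wrappers; the `ppsci.arch.MLP` arguments are illustrative assumptions, while `amsgrad=True` and the `optimizer(model)` call pattern follow the diff and the classes' existing docstring examples:

```python
import ppsci

# Illustrative model configuration (assumed, not taken from this commit).
model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)

# The new keyword added by this commit; on Paddle builds whose Adam lacks
# AMSGrad support, the wrapper falls back to the plain optimizer.
optimizer = ppsci.optimizer.Adam(learning_rate=1e-3, amsgrad=True)(model)
```

The same keyword is available on `ppsci.optimizer.AdamW`.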