@@ -188,6 +188,8 @@ class Adam:
         weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): Regularization strategy. Defaults to None.
         grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient clipping strategy. Defaults to None.
         lazy_mode (bool, optional): Whether to enable lazy mode for moving-average. Defaults to False.
+        amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper
+            `On the Convergence of Adam and Beyond <https://openreview.net/forum?id=ryQu7f-RZ>`_. Defaults to False.

     Examples:
         >>> import ppsci
@@ -208,6 +210,7 @@ def __init__(
             Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]
         ] = None,
         lazy_mode: bool = False,
+        amsgrad: bool = False,
     ):
         self.learning_rate = learning_rate
         self.beta1 = beta1
@@ -217,6 +220,7 @@ def __init__(
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.lazy_mode = lazy_mode
+        self.amsgrad = amsgrad

     def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         # model_list is None in static graph
@@ -225,6 +229,11 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         parameters = (
             sum([m.parameters() for m in model_list], []) if model_list else None
         )
+        import inspect
+
+        extra_kwargs = {}
+        if "amsgrad" in inspect.signature(optim.Adam.__init__).parameters:
+            extra_kwargs["amsgrad"] = self.amsgrad
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -234,6 +243,7 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             grad_clip=self.grad_clip,
             lazy_mode=self.lazy_mode,
             parameters=parameters,
+            **extra_kwargs,
         )
         return opt

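With this hunk the `Adam` wrapper accepts `amsgrad` and forwards it to `paddle.optimizer.Adam` only when the installed Paddle build exposes that keyword, so older Paddle versions keep working. A minimal usage sketch of the new flag (the `MLP` arguments and the tuple-of-models call are illustrative, not part of this diff):

    import ppsci

    model = ppsci.arch.MLP(("x",), ("u",), 5, 20)  # illustrative model, not from this diff
    opt = ppsci.optimizer.Adam(learning_rate=1e-3, amsgrad=True)((model,))
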
@@ -386,6 +396,8 @@ class AdamW:
         grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient clipping strategy. Defaults to None.
         no_weight_decay_name (Optional[str]): List of names of no weight decay parameters split by white space. Defaults to None.
         one_dim_param_no_weight_decay (bool, optional): Apply no weight decay on 1-D parameter(s). Defaults to False.
+        amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper
+            `On the Convergence of Adam and Beyond <https://openreview.net/forum?id=ryQu7f-RZ>`_. Defaults to False.

     Examples:
         >>> import ppsci
@@ -405,6 +417,7 @@ def __init__(
         ] = None,
         no_weight_decay_name: Optional[str] = None,
         one_dim_param_no_weight_decay: bool = False,
+        amsgrad: bool = False,
     ):
         super().__init__()
         self.learning_rate = learning_rate
@@ -417,6 +430,7 @@ def __init__(
             no_weight_decay_name.split() if no_weight_decay_name else []
         )
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
+        self.amsgrad = amsgrad

     def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
         # model_list is None in static graph
@@ -458,6 +472,11 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             if model_list
             else []
         )
+        import inspect
+
+        extra_kwargs = {}
+        if "amsgrad" in inspect.signature(optim.AdamW.__init__).parameters:
+            extra_kwargs["amsgrad"] = self.amsgrad

         opt = optim.AdamW(
             learning_rate=self.learning_rate,
@@ -468,6 +487,7 @@ def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             apply_decay_param_fun=self._apply_decay_param_fun,
+            **extra_kwargs,
         )
         return opt

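The `AdamW` path mirrors the `Adam` one: the `amsgrad` keyword is probed via `inspect.signature` and only forwarded when the installed `paddle.optimizer.AdamW` accepts it. A standalone sketch of that compatibility pattern, assuming only public Paddle APIs (the `Linear` layer exists solely to supply parameters):

    import inspect

    import paddle
    import paddle.optimizer as optim

    layer = paddle.nn.Linear(4, 4)  # any layer works; we only need its parameters
    extra_kwargs = {}
    if "amsgrad" in inspect.signature(optim.AdamW.__init__).parameters:
        extra_kwargs["amsgrad"] = True  # forwarded only on Paddle builds that support AMSGrad
    opt = optim.AdamW(
        learning_rate=1e-3,
        parameters=layer.parameters(),
        **extra_kwargs,
    )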