From f48327e3c8540932d1ffc7fdafc49a9c16343656 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 28 Jun 2023 14:50:52 +0300
Subject: [PATCH] Include additional norm classes into list of types that have
 bias

---
 src/super_gradients/training/utils/optimizer_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/training/utils/optimizer_utils.py b/src/super_gradients/training/utils/optimizer_utils.py
index 5f47a7f72e..0d52ee5acf 100755
--- a/src/super_gradients/training/utils/optimizer_utils.py
+++ b/src/super_gradients/training/utils/optimizer_utils.py
@@ -62,11 +62,11 @@ def _get_no_decay_param_ids(module: nn.Module):
     NOTE - ALL MODULES WITH ATTRIBUTES NAMED BIAS AND ARE INSTANCE OF nn.Parameter WILL BE CONSIDERED A BIAS PARAM FOR
         ZERO WEIGHT DECAY.
     """
-    batchnorm_types = (_BatchNorm,)
+    norm_types = (_BatchNorm, nn.GroupNorm, nn.LayerNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
     torch_weight_with_bias_types = (_ConvNd, nn.Linear)
     no_decay_ids = []
     for name, m in module.named_modules():
-        if isinstance(m, batchnorm_types):
+        if isinstance(m, norm_types):
             no_decay_ids.append(id(m.weight))
             no_decay_ids.append(id(m.bias))
         elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
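
Note: below is a minimal sketch, not repository code, of how the no-decay id
list returned by the patched _get_no_decay_param_ids is typically consumed
when building optimizer parameter groups. The build_param_groups helper, the
toy model, and the optimizer call are assumptions for illustration; only the
function body mirrors the diff above (the torch_weight_with_bias_types
context line is omitted because it is not used inside the shown hunk).

    import torch
    import torch.nn as nn
    from torch.nn.modules.batchnorm import _BatchNorm


    def _get_no_decay_param_ids(module: nn.Module):
        # Patched behavior: all common normalization layers, not only
        # BatchNorm, have their weight and bias exempted from weight decay.
        norm_types = (_BatchNorm, nn.GroupNorm, nn.LayerNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
        no_decay_ids = []
        for name, m in module.named_modules():
            if isinstance(m, norm_types):
                no_decay_ids.append(id(m.weight))
                no_decay_ids.append(id(m.bias))
            elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
                no_decay_ids.append(id(m.bias))
        return no_decay_ids


    def build_param_groups(model: nn.Module, weight_decay: float):
        # Hypothetical helper: split parameters by id into a decayed group
        # and a zero-decay group, in the format PyTorch optimizers accept.
        no_decay = set(_get_no_decay_param_ids(model))
        return [
            {"params": [p for p in model.parameters() if id(p) not in no_decay], "weight_decay": weight_decay},
            {"params": [p for p in model.parameters() if id(p) in no_decay], "weight_decay": 0.0},
        ]


    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8), nn.ReLU())
    optimizer = torch.optim.SGD(build_param_groups(model, weight_decay=1e-4), lr=0.1, momentum=0.9)
    # With this patch, the GroupNorm weight and bias (and the Conv bias) land
    # in the weight_decay=0.0 group; previously only _BatchNorm layers were
    # exempted, so GroupNorm/LayerNorm/InstanceNorm weights were decayed.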