diff --git a/mmseg/core/__init__.py b/mmseg/core/__init__.py index c91334926e..1a077d2f1f 100644 --- a/mmseg/core/__init__.py +++ b/mmseg/core/__init__.py @@ -2,12 +2,10 @@ from .builder import (OPTIMIZER_BUILDERS, build_optimizer, build_optimizer_constructor) from .evaluation import * # noqa: F401, F403 -from .layer_decay_optimizer_constructor import \ - LayerDecayOptimizerConstructor # noqa: F401 +from .optimizers import * # noqa: F401, F403 from .seg import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 __all__ = [ - 'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer', - 'build_optimizer_constructor' + 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' ] diff --git a/mmseg/core/layer_decay_optimizer_constructor.py b/mmseg/core/layer_decay_optimizer_constructor.py deleted file mode 100644 index bd3db92c5c..0000000000 --- a/mmseg/core/layer_decay_optimizer_constructor.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmcv.runner import DefaultOptimizerConstructor, get_dist_info - -from mmseg.utils import get_root_logger -from .builder import OPTIMIZER_BUILDERS - - -def get_num_layer_for_vit(var_name, num_max_layer): - """Get the layer id to set the different learning rates. - - Args: - var_name (str): The key of the model. - num_max_layer (int): Maximum number of backbone layers. - Returns: - layer id (int): Returns the layer id of the key. - """ - - if var_name in ('backbone.cls_token', 'backbone.mask_token', - 'backbone.pos_embed'): - return 0 - elif var_name.startswith('backbone.patch_embed'): - return 0 - elif var_name.startswith('backbone.layers'): - layer_id = int(var_name.split('.')[2]) - return layer_id + 1 - else: - return num_max_layer - 1 - - -@OPTIMIZER_BUILDERS.register_module() -class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): - """Different learning rates are set for different layers of backbone.""" - - def add_params(self, params, module): - """Add all parameters of module to the params list. - - The parameters of the given module will be added to the list of param - groups, with specific rules defined by paramwise_cfg. - Args: - params (list[dict]): A list of param groups, it will be modified - in place. - module (nn.Module): The module to be added. - """ - parameter_groups = {} - logger = get_root_logger() - logger.info(self.paramwise_cfg) - num_layers = self.paramwise_cfg.get('num_layers') + 2 - layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') - logger.info(f'Build LayerDecayOptimizerConstructor ' - f'{layer_decay_rate} - {num_layers}') - weight_decay = self.base_wd - for name, param in module.named_parameters(): - if not param.requires_grad: - continue # frozen weights - if len(param.shape) == 1 or name.endswith('.bias') or name in ( - 'pos_embed', 'cls_token'): - group_name = 'no_decay' - this_weight_decay = 0. 
- else: - group_name = 'decay' - this_weight_decay = weight_decay - layer_id = get_num_layer_for_vit(name, num_layers) - group_name = f'layer_{layer_id}_{group_name}' - if group_name not in parameter_groups: - scale = layer_decay_rate**(num_layers - layer_id - 1) - parameter_groups[group_name] = { - 'weight_decay': this_weight_decay, - 'params': [], - 'param_names': [], - 'lr_scale': scale, - 'group_name': group_name, - 'lr': scale * self.base_lr - } - parameter_groups[group_name]['params'].append(param) - parameter_groups[group_name]['param_names'].append(name) - rank, _ = get_dist_info() - if rank == 0: - to_display = {} - for key in parameter_groups: - to_display[key] = { - 'param_names': parameter_groups[key]['param_names'], - 'lr_scale': parameter_groups[key]['lr_scale'], - 'lr': parameter_groups[key]['lr'], - 'weight_decay': parameter_groups[key]['weight_decay'] - } - logger.info(f'Param groups ={to_display}') - params.extend(parameter_groups.values()) diff --git a/mmseg/core/optimizers/__init__.py b/mmseg/core/optimizers/__init__.py new file mode 100644 index 0000000000..4fbf4ecfcd --- /dev/null +++ b/mmseg/core/optimizers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_decay_optimizer_constructor import ( + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' +] diff --git a/mmseg/core/utils/layer_decay_optimizer_constructor.py b/mmseg/core/optimizers/layer_decay_optimizer_constructor.py similarity index 59% rename from mmseg/core/utils/layer_decay_optimizer_constructor.py rename to mmseg/core/optimizers/layer_decay_optimizer_constructor.py index 29804878cd..ce376760bd 100644 --- a/mmseg/core/utils/layer_decay_optimizer_constructor.py +++ b/mmseg/core/optimizers/layer_decay_optimizer_constructor.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import json +import warnings from mmcv.runner import DefaultOptimizerConstructor, get_dist_info @@ -7,13 +8,13 @@ from ..builder import OPTIMIZER_BUILDERS -def get_num_layer_layer_wise(var_name, num_max_layer=12): +def get_layer_id_for_convnext(var_name, max_layer_id): """Get the layer id to set the different learning rates in ``layer_wise`` decay_type. Args: var_name (str): The key of the model. - num_max_layer (int): Maximum number of backbone layers. + max_layer_id (int): Maximum number of backbone layers. Returns: int: The id number corresponding to different learning rate in @@ -32,7 +33,7 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12): elif stage_id == 2: layer_id = 3 elif stage_id == 3: - layer_id = num_max_layer + layer_id = max_layer_id return layer_id elif var_name.startswith('backbone.stages'): stage_id = int(var_name.split('.')[2]) @@ -44,19 +45,20 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12): elif stage_id == 2: layer_id = 3 + block_id // 3 elif stage_id == 3: - layer_id = num_max_layer + layer_id = max_layer_id return layer_id else: - return num_max_layer + 1 + return max_layer_id + 1 -def get_num_layer_stage_wise(var_name, num_max_layer): - """Get the layer id to set the different learning rates in ``stage_wise`` +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` decay_type. Args: var_name (str): The key of the model. - num_max_layer (int): Maximum number of backbone layers. + max_stage_id (int): Maximum number of backbone layers. 
+ Returns: int: The id number corresponding to different learning rate in ``LearningRateDecayOptimizerConstructor``. @@ -71,14 +73,41 @@ def get_num_layer_stage_wise(var_name, num_max_layer): stage_id = int(var_name.split('.')[2]) return stage_id + 1 else: - return num_max_layer - 1 + return max_stage_id - 1 + + +def get_layer_id_for_vit(var_name, max_layer_id): + """Get the layer id to set the different learning rates. + + Args: + var_name (str): The key of the model. + num_max_layer (int): Maximum number of backbone layers. + + Returns: + int: Returns the layer id of the key. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return max_layer_id - 1 @OPTIMIZER_BUILDERS.register_module() class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): - """Different learning rates are set for different layers of backbone.""" + """Different learning rates are set for different layers of backbone. - def add_params(self, params, module): + Note: Currently, this optimizer constructor is built for ConvNeXt + and BEiT. + """ + + def add_params(self, params, module, **kwargs): """Add all parameters of module to the params list. The parameters of the given module will be added to the list of param @@ -99,7 +128,6 @@ def add_params(self, params, module): logger.info('Build LearningRateDecayOptimizerConstructor ' f'{decay_type} {decay_rate} - {num_layers}') weight_decay = self.base_wd - for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights @@ -110,14 +138,22 @@ def add_params(self, params, module): else: group_name = 'decay' this_weight_decay = weight_decay - - if decay_type == 'layer_wise': - layer_id = get_num_layer_layer_wise( - name, self.paramwise_cfg.get('num_layers')) - logger.info(f'set param {name} as id {layer_id}') + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + logger.info(f'set param {name} as id {layer_id}') + elif 'BEiT' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vit(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() elif decay_type == 'stage_wise': - layer_id = get_num_layer_stage_wise(name, num_layers) - logger.info(f'set param {name} as id {layer_id}') + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() group_name = f'layer_{layer_id}_{group_name}' if group_name not in parameter_groups: @@ -146,3 +182,26 @@ def add_params(self, params, module): } logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') params.extend(parameter_groups.values()) + + +@OPTIMIZER_BUILDERS.register_module() +class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for BEiT, + and it will be deprecated. + Please use ``LearningRateDecayOptimizerConstructor`` instead. 
+ """ + + def __init__(self, optimizer_cfg, paramwise_cfg): + warnings.warn('DeprecationWarning: Original ' + 'LayerDecayOptimizerConstructor of BEiT ' + 'will be deprecated. Please use ' + 'LearningRateDecayOptimizerConstructor instead, ' + 'and set decay_type = layer_wise_vit in paramwise_cfg.') + paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) + warnings.warn('DeprecationWarning: Layer_decay_rate will ' + 'be deleted, please use decay_rate instead.') + paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') + super(LayerDecayOptimizerConstructor, + self).__init__(optimizer_cfg, paramwise_cfg) diff --git a/mmseg/core/utils/__init__.py b/mmseg/core/utils/__init__.py index cb5a0c3fd3..28882893a5 100644 --- a/mmseg/core/utils/__init__.py +++ b/mmseg/core/utils/__init__.py @@ -1,10 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dist_util import check_dist_init, sync_random_seed -from .layer_decay_optimizer_constructor import \ - LearningRateDecayOptimizerConstructor from .misc import add_prefix -__all__ = [ - 'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init', - 'sync_random_seed' -] +__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] diff --git a/tests/test_core/test_layer_decay_optimizer_constructor.py b/tests/test_core/test_layer_decay_optimizer_constructor.py index f595d31331..268a9a1489 100644 --- a/tests/test_core/test_layer_decay_optimizer_constructor.py +++ b/tests/test_core/test_layer_decay_optimizer_constructor.py @@ -1,11 +1,94 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch import torch.nn as nn +from mmcv.cnn import ConvModule -from mmseg.core.layer_decay_optimizer_constructor import \ - LayerDecayOptimizerConstructor +from mmseg.core.optimizers.layer_decay_optimizer_constructor import ( + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) -layer_wise_gt_lst = [{ +base_lr = 1 +decay_rate = 2 +base_wd = 0.05 +weight_decay = 0.05 + +expected_stage_wise_lr_wd_convnext = [{ + 'weight_decay': 0.0, + 'lr_scale': 128 +}, { + 'weight_decay': 0.0, + 'lr_scale': 1 +}, { + 'weight_decay': 0.05, + 'lr_scale': 64 +}, { + 'weight_decay': 0.0, + 'lr_scale': 64 +}, { + 'weight_decay': 0.05, + 'lr_scale': 32 +}, { + 'weight_decay': 0.0, + 'lr_scale': 32 +}, { + 'weight_decay': 0.05, + 'lr_scale': 16 +}, { + 'weight_decay': 0.0, + 'lr_scale': 16 +}, { + 'weight_decay': 0.05, + 'lr_scale': 8 +}, { + 'weight_decay': 0.0, + 'lr_scale': 8 +}, { + 'weight_decay': 0.05, + 'lr_scale': 128 +}, { + 'weight_decay': 0.05, + 'lr_scale': 1 +}] + +expected_layer_wise_lr_wd_convnext = [{ + 'weight_decay': 0.0, + 'lr_scale': 128 +}, { + 'weight_decay': 0.0, + 'lr_scale': 1 +}, { + 'weight_decay': 0.05, + 'lr_scale': 64 +}, { + 'weight_decay': 0.0, + 'lr_scale': 64 +}, { + 'weight_decay': 0.05, + 'lr_scale': 32 +}, { + 'weight_decay': 0.0, + 'lr_scale': 32 +}, { + 'weight_decay': 0.05, + 'lr_scale': 16 +}, { + 'weight_decay': 0.0, + 'lr_scale': 16 +}, { + 'weight_decay': 0.05, + 'lr_scale': 2 +}, { + 'weight_decay': 0.0, + 'lr_scale': 2 +}, { + 'weight_decay': 0.05, + 'lr_scale': 128 +}, { + 'weight_decay': 0.05, + 'lr_scale': 1 +}] + +expected_layer_wise_wd_lr_beit = [{ 'weight_decay': 0.0, 'lr_scale': 16 }, { @@ -26,45 +109,143 @@ }, { 'weight_decay': 0.0, 'lr_scale': 2 +}, { + 'weight_decay': 0.05, + 'lr_scale': 1 +}, { + 'weight_decay': 0.0, + 'lr_scale': 1 }] -class BEiTExampleModel(nn.Module): +class ToyConvNeXt(nn.Module): - def __init__(self, depth): + def __init__(self): 
super().__init__() - self.backbone = nn.ModuleList() + self.stages = nn.ModuleList() + for i in range(4): + stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True)) + self.stages.append(stage) + self.norm0 = nn.BatchNorm2d(2) # add some variables to meet unit test coverate rate - self.backbone.cls_token = nn.Parameter(torch.ones(1)) - self.backbone.patch_embed = nn.Parameter(torch.ones(1)) - self.backbone.layers = nn.ModuleList() - for _ in range(depth): + self.cls_token = nn.Parameter(torch.ones(1)) + self.mask_token = nn.Parameter(torch.ones(1)) + self.pos_embed = nn.Parameter(torch.ones(1)) + self.stem_norm = nn.Parameter(torch.ones(1)) + self.downsample_norm0 = nn.BatchNorm2d(2) + self.downsample_norm1 = nn.BatchNorm2d(2) + self.downsample_norm2 = nn.BatchNorm2d(2) + self.lin = nn.Parameter(torch.ones(1)) + self.lin.requires_grad = False + self.downsample_layers = nn.ModuleList() + for _ in range(4): + stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True)) + self.downsample_layers.append(stage) + + +class ToyBEiT(nn.Module): + + def __init__(self): + super().__init__() + # add some variables to meet unit test coverate rate + self.cls_token = nn.Parameter(torch.ones(1)) + self.patch_embed = nn.Parameter(torch.ones(1)) + self.layers = nn.ModuleList() + for _ in range(3): layer = nn.Conv2d(3, 3, 1) - self.backbone.layers.append(layer) + self.layers.append(layer) -def check_beit_adamw_optimizer(optimizer, gt_lst): +class ToySegmentor(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2) + + +class PseudoDataParallel(nn.Module): + + def __init__(self, model): + super().__init__() + self.module = model + + +class ToyViT(nn.Module): + + def __init__(self): + super().__init__() + + +def check_optimizer_lr_wd(optimizer, gt_lr_wd): assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults['lr'] == 1 - assert optimizer.defaults['weight_decay'] == 0.05 + assert optimizer.defaults['lr'] == base_lr + assert optimizer.defaults['weight_decay'] == base_wd param_groups = optimizer.param_groups - # 1 layer (cls_token and patch_embed) + 3 layers * 2 (w, b) = 7 layers - assert len(param_groups) == 7 + print(param_groups) + assert len(param_groups) == len(gt_lr_wd) for i, param_dict in enumerate(param_groups): - assert param_dict['weight_decay'] == gt_lst[i]['weight_decay'] - assert param_dict['lr_scale'] == gt_lst[i]['lr_scale'] + assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay'] + assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale'] assert param_dict['lr_scale'] == param_dict['lr'] +def test_learning_rate_decay_optimizer_constructor(): + + # Test lr wd for ConvNeXT + backbone = ToyConvNeXt() + model = PseudoDataParallel(ToySegmentor(backbone)) + optimizer_cfg = dict( + type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05) + # stagewise decay + stagewise_paramwise_cfg = dict( + decay_rate=decay_rate, decay_type='stage_wise', num_layers=6) + optim_constructor = LearningRateDecayOptimizerConstructor( + optimizer_cfg, stagewise_paramwise_cfg) + optimizer = optim_constructor(model) + check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext) + # layerwise decay + layerwise_paramwise_cfg = dict( + decay_rate=decay_rate, decay_type='layer_wise', num_layers=6) + optim_constructor = LearningRateDecayOptimizerConstructor( + optimizer_cfg, layerwise_paramwise_cfg) + optimizer = optim_constructor(model) + check_optimizer_lr_wd(optimizer, 
expected_layer_wise_lr_wd_convnext) + + # Test lr wd for BEiT + backbone = ToyBEiT() + model = PseudoDataParallel(ToySegmentor(backbone)) + + layerwise_paramwise_cfg = dict( + decay_rate=decay_rate, decay_type='layer_wise', num_layers=3) + optim_constructor = LearningRateDecayOptimizerConstructor( + optimizer_cfg, layerwise_paramwise_cfg) + optimizer = optim_constructor(model) + check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit) + + # Test invalidation of lr wd for Vit + backbone = ToyViT() + model = PseudoDataParallel(ToySegmentor(backbone)) + with pytest.raises(NotImplementedError): + optim_constructor = LearningRateDecayOptimizerConstructor( + optimizer_cfg, layerwise_paramwise_cfg) + optimizer = optim_constructor(model) + with pytest.raises(NotImplementedError): + optim_constructor = LearningRateDecayOptimizerConstructor( + optimizer_cfg, stagewise_paramwise_cfg) + optimizer = optim_constructor(model) + + def test_beit_layer_decay_optimizer_constructor(): - # paramwise_cfg with ConvNeXtExampleModel - model = BEiTExampleModel(depth=3) + # paramwise_cfg with BEiTExampleModel + backbone = ToyBEiT() + model = PseudoDataParallel(ToySegmentor(backbone)) optimizer_cfg = dict( type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05) - paramwise_cfg = dict(num_layers=3, layer_decay_rate=2) + paramwise_cfg = dict(layer_decay_rate=2, num_layers=3) optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg, paramwise_cfg) optimizer = optim_constructor(model) - check_beit_adamw_optimizer(optimizer, layer_wise_gt_lst) + check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit) diff --git a/tests/test_core/test_learning_rate_decay_optimizer_constructor.py b/tests/test_core/test_learning_rate_decay_optimizer_constructor.py deleted file mode 100644 index 204ca45b9e..0000000000 --- a/tests/test_core/test_learning_rate_decay_optimizer_constructor.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import torch -import torch.nn as nn -from mmcv.cnn import ConvModule - -from mmseg.core.utils.layer_decay_optimizer_constructor import \ - LearningRateDecayOptimizerConstructor - -base_lr = 1 -decay_rate = 2 -base_wd = 0.05 -weight_decay = 0.05 - -stage_wise_gt_lst = [{ - 'weight_decay': 0.0, - 'lr_scale': 128 -}, { - 'weight_decay': 0.0, - 'lr_scale': 1 -}, { - 'weight_decay': 0.05, - 'lr_scale': 64 -}, { - 'weight_decay': 0.0, - 'lr_scale': 64 -}, { - 'weight_decay': 0.05, - 'lr_scale': 32 -}, { - 'weight_decay': 0.0, - 'lr_scale': 32 -}, { - 'weight_decay': 0.05, - 'lr_scale': 16 -}, { - 'weight_decay': 0.0, - 'lr_scale': 16 -}, { - 'weight_decay': 0.05, - 'lr_scale': 8 -}, { - 'weight_decay': 0.0, - 'lr_scale': 8 -}, { - 'weight_decay': 0.05, - 'lr_scale': 128 -}, { - 'weight_decay': 0.05, - 'lr_scale': 1 -}] - -layer_wise_gt_lst = [{ - 'weight_decay': 0.0, - 'lr_scale': 128 -}, { - 'weight_decay': 0.0, - 'lr_scale': 1 -}, { - 'weight_decay': 0.05, - 'lr_scale': 64 -}, { - 'weight_decay': 0.0, - 'lr_scale': 64 -}, { - 'weight_decay': 0.05, - 'lr_scale': 32 -}, { - 'weight_decay': 0.0, - 'lr_scale': 32 -}, { - 'weight_decay': 0.05, - 'lr_scale': 16 -}, { - 'weight_decay': 0.0, - 'lr_scale': 16 -}, { - 'weight_decay': 0.05, - 'lr_scale': 2 -}, { - 'weight_decay': 0.0, - 'lr_scale': 2 -}, { - 'weight_decay': 0.05, - 'lr_scale': 128 -}, { - 'weight_decay': 0.05, - 'lr_scale': 1 -}] - - -class ConvNeXtExampleModel(nn.Module): - - def __init__(self): - super().__init__() - self.backbone = nn.ModuleList() - self.backbone.stages = nn.ModuleList() - for i in range(4): - stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True)) - self.backbone.stages.append(stage) - self.backbone.norm0 = nn.BatchNorm2d(2) - - # add some variables to meet unit test coverate rate - self.backbone.cls_token = nn.Parameter(torch.ones(1)) - self.backbone.mask_token = nn.Parameter(torch.ones(1)) - self.backbone.pos_embed = nn.Parameter(torch.ones(1)) - self.backbone.stem_norm = nn.Parameter(torch.ones(1)) - self.backbone.downsample_norm0 = nn.BatchNorm2d(2) - self.backbone.downsample_norm1 = nn.BatchNorm2d(2) - self.backbone.downsample_norm2 = nn.BatchNorm2d(2) - self.backbone.lin = nn.Parameter(torch.ones(1)) - self.backbone.lin.requires_grad = False - - self.backbone.downsample_layers = nn.ModuleList() - for i in range(4): - stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True)) - self.backbone.downsample_layers.append(stage) - - self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2) - - -class PseudoDataParallel(nn.Module): - - def __init__(self): - super().__init__() - self.module = ConvNeXtExampleModel() - - def forward(self, x): - return x - - -def check_convnext_adamw_optimizer(optimizer, gt_lst): - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults['lr'] == base_lr - assert optimizer.defaults['weight_decay'] == base_wd - param_groups = optimizer.param_groups - assert len(param_groups) == 12 - for i, param_dict in enumerate(param_groups): - assert param_dict['weight_decay'] == gt_lst[i]['weight_decay'] - assert param_dict['lr_scale'] == gt_lst[i]['lr_scale'] - assert param_dict['lr_scale'] == param_dict['lr'] - - -def test_convnext_learning_rate_decay_optimizer_constructor(): - - # paramwise_cfg with ConvNeXtExampleModel - model = ConvNeXtExampleModel() - optimizer_cfg = dict( - type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05) - stagewise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='stage_wise', num_layers=6) - optim_constructor = 
LearningRateDecayOptimizerConstructor( - optimizer_cfg, stagewise_paramwise_cfg) - optimizer = optim_constructor(model) - check_convnext_adamw_optimizer(optimizer, stage_wise_gt_lst) - - layerwise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='layer_wise', num_layers=6) - optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, layerwise_paramwise_cfg) - optimizer = optim_constructor(model) - check_convnext_adamw_optimizer(optimizer, layer_wise_gt_lst)
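
Usage sketch for the unified constructor after this refactor. The values below (lr=6e-5, decay_rate=0.9, num_layers=12) are illustrative assumptions, not taken from this patch; both decay strategies are selected through paramwise_cfg of the registered LearningRateDecayOptimizerConstructor, and build_optimizer picks the constructor via the 'constructor' key:

    # layer-wise decay, e.g. for a BEiT backbone
    optimizer = dict(
        type='AdamW',
        lr=6e-5,                      # illustrative base lr
        betas=(0.9, 0.999),
        weight_decay=0.05,
        constructor='LearningRateDecayOptimizerConstructor',
        paramwise_cfg=dict(
            decay_rate=0.9,           # illustrative decay factor
            decay_type='layer_wise',  # use 'stage_wise' for ConvNeXt stage-wise decay
            num_layers=12))           # illustrative number of backbone blocks

    # the deprecated entry point keeps working: it warns, renames
    # layer_decay_rate -> decay_rate and sets decay_type='layer_wise_vit'
    optimizer = dict(
        type='AdamW',
        lr=6e-5,
        betas=(0.9, 0.999),
        weight_decay=0.05,
        constructor='LayerDecayOptimizerConstructor',
        paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))

Each parameter group is scaled by lr_scale = decay_rate ** (num_layers - layer_id - 1), where num_layers is the paramwise_cfg value plus 2 (as in the removed LayerDecayOptimizerConstructor). This is where the lr_scale values in the updated tests come from: with decay_rate=2 and num_layers=6, id 0 (cls_token / pos_embed / patch_embed) gets 2 ** 7 = 128, down to 1 for the head.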