
Commit

[Refactor] Merge BEiT and ConvNeXt's LR decay optimizer constructors (#1438)

* move layer_decay_optimizer_constructor

* fix

* fix

* merge test_core

* fix

* add DeprecationWarning

* fix DeprecationWarning

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix test

* fix

* fix

* fix

* fix

* fix ut

* fix

* fix

* Update tests/test_core/test_layer_decay_optimizer_constructor.py

* fix

* fix

* fix

* fix

Co-authored-by: MeowZheng <meowzheng@outlook.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
3 people committed Apr 27, 2022
1 parent 62c3a7d commit f16bb06
Showing 7 changed files with 291 additions and 299 deletions.
6 changes: 2 additions & 4 deletions mmseg/core/__init__.py
@@ -2,12 +2,10 @@
 from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
                       build_optimizer_constructor)
 from .evaluation import *  # noqa: F401, F403
-from .layer_decay_optimizer_constructor import \
-    LayerDecayOptimizerConstructor  # noqa: F401
+from .optimizers import *  # noqa: F401, F403
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403

 __all__ = [
-    'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
-    'build_optimizer_constructor'
+    'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
 ]
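After this hunk, the constructors are no longer defined at the mmseg.core root; they come in through the new optimizers subpackage and its wildcard re-export. A minimal import sketch under that assumption (illustrative only, not part of the diff):

# New canonical location introduced by this commit.
from mmseg.core.optimizers import LearningRateDecayOptimizerConstructor

# Still reachable from the package root, because mmseg/core/__init__.py
# now re-exports the subpackage with `from .optimizers import *`.
from mmseg.core import LayerDecayOptimizerConstructor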
87 changes: 0 additions & 87 deletions mmseg/core/layer_decay_optimizer_constructor.py

This file was deleted.

7 changes: 7 additions & 0 deletions mmseg/core/optimizers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .layer_decay_optimizer_constructor import (
+    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
+
+__all__ = [
+    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
+]
mmseg/core/optimizers/layer_decay_optimizer_constructor.py
@@ -1,19 +1,20 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
+import warnings

 from mmcv.runner import DefaultOptimizerConstructor, get_dist_info

 from mmseg.utils import get_root_logger
 from ..builder import OPTIMIZER_BUILDERS


-def get_num_layer_layer_wise(var_name, num_max_layer=12):
+def get_layer_id_for_convnext(var_name, max_layer_id):
     """Get the layer id to set the different learning rates in ``layer_wise``
     decay_type.

     Args:
         var_name (str): The key of the model.
-        num_max_layer (int): Maximum number of backbone layers.
+        max_layer_id (int): Maximum number of backbone layers.

     Returns:
         int: The id number corresponding to different learning rate in
@@ -32,7 +33,7 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12):
         elif stage_id == 2:
             layer_id = 3
         elif stage_id == 3:
-            layer_id = num_max_layer
+            layer_id = max_layer_id
         return layer_id
     elif var_name.startswith('backbone.stages'):
         stage_id = int(var_name.split('.')[2])
@@ -44,19 +45,20 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12):
         elif stage_id == 2:
             layer_id = 3 + block_id // 3
         elif stage_id == 3:
-            layer_id = num_max_layer
+            layer_id = max_layer_id
         return layer_id
     else:
-        return num_max_layer + 1
+        return max_layer_id + 1


-def get_num_layer_stage_wise(var_name, num_max_layer):
-    """Get the layer id to set the different learning rates in ``stage_wise``
+def get_stage_id_for_convnext(var_name, max_stage_id):
+    """Get the stage id to set the different learning rates in ``stage_wise``
     decay_type.

     Args:
         var_name (str): The key of the model.
-        num_max_layer (int): Maximum number of backbone layers.
+        max_stage_id (int): Maximum number of backbone layers.

     Returns:
         int: The id number corresponding to different learning rate in
         ``LearningRateDecayOptimizerConstructor``.
@@ -71,14 +73,41 @@ def get_num_layer_stage_wise(var_name, num_max_layer):
         stage_id = int(var_name.split('.')[2])
         return stage_id + 1
     else:
-        return num_max_layer - 1
+        return max_stage_id - 1


+def get_layer_id_for_vit(var_name, max_layer_id):
+    """Get the layer id to set the different learning rates.
+
+    Args:
+        var_name (str): The key of the model.
+        num_max_layer (int): Maximum number of backbone layers.
+
+    Returns:
+        int: Returns the layer id of the key.
+    """
+
+    if var_name in ('backbone.cls_token', 'backbone.mask_token',
+                    'backbone.pos_embed'):
+        return 0
+    elif var_name.startswith('backbone.patch_embed'):
+        return 0
+    elif var_name.startswith('backbone.layers'):
+        layer_id = int(var_name.split('.')[2])
+        return layer_id + 1
+    else:
+        return max_layer_id - 1
+
+
 @OPTIMIZER_BUILDERS.register_module()
 class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
-    """Different learning rates are set for different layers of backbone."""
+    """Different learning rates are set for different layers of backbone.
+
+    Note: Currently, this optimizer constructor is built for ConvNeXt
+    and BEiT.
+    """

-    def add_params(self, params, module):
+    def add_params(self, params, module, **kwargs):
         """Add all parameters of module to the params list.

         The parameters of the given module will be added to the list of param
@@ -99,7 +128,6 @@ def add_params(self, params, module):
         logger.info('Build LearningRateDecayOptimizerConstructor '
                     f'{decay_type} {decay_rate} - {num_layers}')
         weight_decay = self.base_wd
-
         for name, param in module.named_parameters():
             if not param.requires_grad:
                 continue  # frozen weights
@@ -110,14 +138,22 @@
             else:
                 group_name = 'decay'
                 this_weight_decay = weight_decay
-
-            if decay_type == 'layer_wise':
-                layer_id = get_num_layer_layer_wise(
-                    name, self.paramwise_cfg.get('num_layers'))
-                logger.info(f'set param {name} as id {layer_id}')
+            if 'layer_wise' in decay_type:
+                if 'ConvNeXt' in module.backbone.__class__.__name__:
+                    layer_id = get_layer_id_for_convnext(
+                        name, self.paramwise_cfg.get('num_layers'))
+                    logger.info(f'set param {name} as id {layer_id}')
+                elif 'BEiT' in module.backbone.__class__.__name__:
+                    layer_id = get_layer_id_for_vit(name, num_layers)
+                    logger.info(f'set param {name} as id {layer_id}')
+                else:
+                    raise NotImplementedError()
             elif decay_type == 'stage_wise':
-                layer_id = get_num_layer_stage_wise(name, num_layers)
-                logger.info(f'set param {name} as id {layer_id}')
+                if 'ConvNeXt' in module.backbone.__class__.__name__:
+                    layer_id = get_stage_id_for_convnext(name, num_layers)
+                    logger.info(f'set param {name} as id {layer_id}')
+                else:
+                    raise NotImplementedError()
             group_name = f'layer_{layer_id}_{group_name}'

             if group_name not in parameter_groups:
@@ -146,3 +182,26 @@ def add_params(self, params, module):
                 }
         logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
         params.extend(parameter_groups.values())
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
+    """Different learning rates are set for different layers of backbone.
+
+    Note: Currently, this optimizer constructor is built for BEiT,
+    and it will be deprecated.
+    Please use ``LearningRateDecayOptimizerConstructor`` instead.
+    """
+
+    def __init__(self, optimizer_cfg, paramwise_cfg):
+        warnings.warn('DeprecationWarning: Original '
+                      'LayerDecayOptimizerConstructor of BEiT '
+                      'will be deprecated. Please use '
+                      'LearningRateDecayOptimizerConstructor instead, '
+                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
+        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
+        warnings.warn('DeprecationWarning: Layer_decay_rate will '
+                      'be deleted, please use decay_rate instead.')
+        paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
+        super(LayerDecayOptimizerConstructor,
+              self).__init__(optimizer_cfg, paramwise_cfg)
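The merged constructor is selected through the optimizer config rather than imported directly. Below is a hedged sketch of how a config might request it after this commit; the optimizer type and every numeric value (lr, weight_decay, decay_rate, num_layers) are illustrative placeholders, not values prescribed by this diff. The second dict is shown under a different variable name only so both fit in one snippet.

# Sketch of an optimizer config using the merged constructor.
optimizer = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW',
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        decay_rate=0.9,           # per-layer/stage LR decay factor
        decay_type='layer_wise',  # 'stage_wise' is also accepted for ConvNeXt
        num_layers=12))

# The deprecated entry point keeps old BEiT configs working: it forces
# decay_type to 'layer_wise_vit' and pops `layer_decay_rate` into
# `decay_rate` before delegating to LearningRateDecayOptimizerConstructor.
legacy_optimizer = dict(
    constructor='LayerDecayOptimizerConstructor',
    type='AdamW',
    lr=1e-4,
    weight_decay=0.05,
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))

Since add_params dispatches on the backbone class name ('ConvNeXt' vs 'BEiT'), one constructor now covers both families, and unsupported combinations raise NotImplementedError.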
7 changes: 1 addition & 6 deletions mmseg/core/utils/__init__.py
@@ -1,10 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dist_util import check_dist_init, sync_random_seed
-from .layer_decay_optimizer_constructor import \
-    LearningRateDecayOptimizerConstructor
 from .misc import add_prefix

-__all__ = [
-    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
-    'sync_random_seed'
-]
+__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed']
