[Refactor] Merge BEiT and ConvNeXt's LR decay optimizer constructors #1438

Merged 36 commits on Apr 27, 2022
Changes from 10 commits
Commits (36)
f71cef5  move layer_decay_optimizer_constructor (linfangjian01, Mar 31, 2022)
6aa86d2  fix (linfangjian01, Mar 31, 2022)
936ab01  fix (linfangjian01, Mar 31, 2022)
102c54f  merge test_core (linfangjian01, Mar 31, 2022)
16e0b2f  fix (linfangjian01, Apr 1, 2022)
d7c0d67  add DeprecationWarning (linfangjian01, Apr 1, 2022)
2e7e579  fix DeprecationWarning (linfangjian01, Apr 1, 2022)
278fc81  fix (linfangjian01, Apr 2, 2022)
e57d79e  fix (linfangjian01, Apr 2, 2022)
4d8131d  fix (linfangjian01, Apr 3, 2022)
29ebc06  fix (linfangjian01, Apr 5, 2022)
14d2026  fix (linfangjian01, Apr 5, 2022)
db10656  fix (linfangjian01, Apr 6, 2022)
e8a6a6b  fix (linfangjian01, Apr 6, 2022)
f8eb1b7  fix (linfangjian01, Apr 6, 2022)
1f58c69  fix (linfangjian01, Apr 7, 2022)
c931d33  fix (linfangjian01, Apr 12, 2022)
b2bac1f  Merge branch 'open-mmlab:master' into moveoptim (linfangjian01, Apr 13, 2022)
c7e461f  Merge branch 'open-mmlab:master' into moveoptim (linfangjian01, Apr 13, 2022)
e9b1999  fix (linfangjian01, Apr 25, 2022)
1b71c3c  fix (linfangjian01, Apr 25, 2022)
f7d1f44  fix (linfangjian01, Apr 25, 2022)
3052fe7  fix test (linfangjian01, Apr 25, 2022)
5a056fe  fix (linfangjian01, Apr 25, 2022)
875d047  fix (linfangjian01, Apr 25, 2022)
6eb5bc0  fix (linfangjian01, Apr 25, 2022)
f99eb9a  fix (linfangjian01, Apr 26, 2022)
97fb999  fix ut (MeowZheng, Apr 26, 2022)
60eff24  fix ut (MeowZheng, Apr 26, 2022)
2acb909  fix (linfangjian01, Apr 26, 2022)
494c8ac  fix (linfangjian01, Apr 26, 2022)
50d56c6  Update tests/test_core/test_layer_decay_optimizer_constructor.py (MeowZheng, Apr 26, 2022)
4d13e41  fix (linfangjian01, Apr 26, 2022)
6285ae8  fix (linfangjian01, Apr 26, 2022)
8e11179  fix (linfangjian01, Apr 26, 2022)
23b920d  fix (linfangjian01, Apr 26, 2022)
3 changes: 1 addition & 2 deletions mmseg/core/__init__.py
@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
from .layer_decay_optimizer_constructor import \
LayerDecayOptimizerConstructor # noqa: F401
from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
87 changes: 0 additions & 87 deletions mmseg/core/layer_decay_optimizer_constructor.py

This file was deleted.

7 changes: 7 additions & 0 deletions mmseg/core/optimizers/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .layer_decay_optimizer_constructor import (
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)

__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
]
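
With the new subpackage in place, both constructors are exposed from mmseg.core.optimizers, and the wildcard re-export added to mmseg/core/__init__.py above keeps them reachable from mmseg.core as well. A minimal import sketch, not part of the diff, just for orientation:

# New canonical location after this PR.
from mmseg.core.optimizers import (LayerDecayOptimizerConstructor,
                                   LearningRateDecayOptimizerConstructor)

# The package-level path should also keep working via the
# `from .optimizers import *` re-export shown above:
# from mmseg.core import LearningRateDecayOptimizerConstructor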
mmseg/core/{utils → optimizers}/layer_decay_optimizer_constructor.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import warnings

from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
get_dist_info)
@@ -74,9 +75,32 @@ def get_num_layer_stage_wise(var_name, num_max_layer):
return num_max_layer - 1


def get_num_layer_for_vit(var_name, num_max_layer):
"""Get the layer id to set the different learning rates.

Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
Returns:
layer id (int): The layer id of the key.
"""

if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.patch_embed'):
return 0
elif var_name.startswith('backbone.layers'):
layer_id = int(var_name.split('.')[2])
return layer_id + 1
else:
return num_max_layer - 1


@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone."""
# Different learning rates are set for different layers of backbone.
# Note: Currently, this optimizer constructor is built for ConvNeXt.

def add_params(self, params, module):
"""Add all parameters of module to the params list.
@@ -115,6 +139,9 @@ def add_params(self, params, module):
layer_id = get_num_layer_layer_wise(
name, self.paramwise_cfg.get('num_layers'))
logger.info(f'set param {name} as id {layer_id}')
elif decay_type == 'layer_wise_vit':
layer_id = get_num_layer_for_vit(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
elif decay_type == 'stage_wise':
layer_id = get_num_layer_stage_wise(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
@@ -146,3 +173,22 @@ def add_params(self, params, module):
}
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
params.extend(parameter_groups.values())


@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
# Different learning rates are set for different layers of backbone.
# Note: Currently, this optimizer constructor is built for BEiT.

def __init__(self, optimizer_cfg, paramwise_cfg):
warnings.warn('DeprecationWarning: Original '
'LayerDecayOptimizerConstructor of BEiT '
'will be deprecated. Please use '
'LearningRateDecayOptimizerConstructor instead, '
'and set decay_type = layer_wise_vit in paramwise_cfg.')
paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
warnings.warn('DeprecationWarning: Layer_decay_rate will '
'be deleted, please use decay_rate instead.')
paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
super(LayerDecayOptimizerConstructor,
self).__init__(optimizer_cfg, paramwise_cfg)
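
For existing BEiT configs, the shim above keeps LayerDecayOptimizerConstructor working while steering users to the merged class: it forces decay_type='layer_wise_vit' and renames layer_decay_rate to decay_rate. A hedged before/after sketch of the intended migration; the lr, num_layers, and decay values below are illustrative, not taken from any shipped config:

# Before (deprecated path, still accepted but emits warnings):
optimizer = dict(
    constructor='LayerDecayOptimizerConstructor',
    type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05,
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))

# After (merged constructor introduced in this PR):
optimizer = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05,
    paramwise_cfg=dict(
        num_layers=12, decay_rate=0.9, decay_type='layer_wise_vit'))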
7 changes: 1 addition & 6 deletions mmseg/core/utils/__init__.py
@@ -1,10 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_util import check_dist_init, sync_random_seed
from .layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
from .misc import add_prefix

__all__ = [
'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
'sync_random_seed'
]
__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed']
164 changes: 160 additions & 4 deletions tests/test_core/test_layer_decay_optimizer_constructor.py
@@ -1,11 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.core.layer_decay_optimizer_constructor import \
LayerDecayOptimizerConstructor
from mmseg.core.optimizers.layer_decay_optimizer_constructor import (
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)

base_lr = 1
decay_rate = 2
base_wd = 0.05
weight_decay = 0.05

stage_wise_gt_lst = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 8
}, {
'weight_decay': 0.0,
'lr_scale': 8
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]

layer_wise_gt_lst = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 2
}, {
'weight_decay': 0.0,
'lr_scale': 2
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]


class ConvNeXtExampleModel(nn.Module):

def __init__(self):
super().__init__()
self.backbone = nn.ModuleList()
self.backbone.stages = nn.ModuleList()
for i in range(4):
stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True))
self.backbone.stages.append(stage)
self.backbone.norm0 = nn.BatchNorm2d(2)

# add some variables to meet unit test coverage
self.backbone.cls_token = nn.Parameter(torch.ones(1))
self.backbone.mask_token = nn.Parameter(torch.ones(1))
self.backbone.pos_embed = nn.Parameter(torch.ones(1))
self.backbone.stem_norm = nn.Parameter(torch.ones(1))
self.backbone.downsample_norm0 = nn.BatchNorm2d(2)
self.backbone.downsample_norm1 = nn.BatchNorm2d(2)
self.backbone.downsample_norm2 = nn.BatchNorm2d(2)
self.backbone.lin = nn.Parameter(torch.ones(1))
self.backbone.lin.requires_grad = False

self.backbone.downsample_layers = nn.ModuleList()
for i in range(4):
stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True))
self.backbone.downsample_layers.append(stage)

self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2)


class PseudoDataParallel(nn.Module):

def __init__(self):
super().__init__()
self.module = ConvNeXtExampleModel()

def forward(self, x):
return x


def check_convnext_adamw_optimizer(optimizer, gt_lst):
assert isinstance(optimizer, torch.optim.AdamW)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['weight_decay'] == base_wd
param_groups = optimizer.param_groups
assert len(param_groups) == 12
for i, param_dict in enumerate(param_groups):
assert param_dict['weight_decay'] == gt_lst[i]['weight_decay']
assert param_dict['lr_scale'] == gt_lst[i]['lr_scale']
assert param_dict['lr_scale'] == param_dict['lr']


def test_convnext_learning_rate_decay_optimizer_constructor():

# paramwise_cfg with ConvNeXtExampleModel
model = ConvNeXtExampleModel()
optimizer_cfg = dict(
type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
stagewise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, stagewise_paramwise_cfg)
optimizer = optim_constructor(model)
check_convnext_adamw_optimizer(optimizer, stage_wise_gt_lst)

layerwise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, layerwise_paramwise_cfg)
optimizer = optim_constructor(model)
check_convnext_adamw_optimizer(optimizer, layer_wise_gt_lst)


layer_wise_wd_lr = [{
'weight_decay': 0.0,
'lr_scale': 16
}, {
@@ -63,8 +219,8 @@ def test_beit_layer_decay_optimizer_constructor():
model = BEiTExampleModel(depth=3)
optimizer_cfg = dict(
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05)
paramwise_cfg = dict(num_layers=3, layer_decay_rate=2)
paramwise_cfg = dict(layer_decay_rate=2, num_layers=3)
optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_beit_adamw_optimizer(optimizer, layer_wise_gt_lst)
check_beit_adamw_optimizer(optimizer, layer_wise_wd_lr)
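
For reference, a small sketch of how the layer_wise_vit branch assigns layer ids to BEiT-style parameter names; it assumes get_num_layer_for_vit stays importable from the module shown above, and the parameter names used here are illustrative:

from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
    get_num_layer_for_vit

# Tokens and the patch embedding are grouped as layer id 0.
assert get_num_layer_for_vit('backbone.cls_token', 5) == 0
assert get_num_layer_for_vit('backbone.patch_embed.proj.weight', 5) == 0
# Transformer blocks map to their block index + 1.
assert get_num_layer_for_vit('backbone.layers.2.attn.qkv.weight', 5) == 3
# Anything else falls into the last group (num_max_layer - 1).
assert get_num_layer_for_vit('decode_head.conv_seg.weight', 5) == 4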