[Feature] Add auto scale lr function (open-mmlab#270)
* Add auto scale lr function

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

Co-authored-by: wangjiabao1.vendor <wangjiabao@pjlab.org.cn>
jbwang1997 authored Jun 6, 2022
1 parent 65bc950 commit 8f3fcee
Showing 2 changed files with 117 additions and 1 deletion.
80 changes: 79 additions & 1 deletion mmengine/runner/runner.py
@@ -101,14 +101,17 @@ class Runner:
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
specified. Defaults to None.
See :meth:`build_test_loop` for more details.
auto_scale_lr_cfg (dict, optional): Config to scale the learning rate
automatically. It includes ``base_batch_size`` and ``enable``.
``base_batch_size`` is the batch size that the optimizer lr is
based on, and ``enable`` is the switch that turns the feature on or
off (see the example below).
optim_wrapper (OptimWrapper or dict, optional):
Computing gradient of model parameters. If specified,
:attr:`train_dataloader` should also be specified. If automatic
mixed precision or gradient accumulation
training is required, the type of ``optim_wrapper`` should be
AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for
examples. Defaults to None.
param_scheduler (_ParamScheduler or dict or list, optional):
Parameter scheduler for updating optimizer parameters. If
specified, :attr:`optimizer` should also be specified.
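For illustration, a minimal sketch of what enabling the feature could look like in a config (the value 16 is only an example; ``base_batch_size`` should be the batch size the configured lr was tuned for):

    # Scale lr linearly when the effective batch size
    # (world_size * per-GPU batch size) differs from base_batch_size.
    auto_scale_lr_cfg = dict(enable=True, base_batch_size=16)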
@@ -185,6 +188,7 @@ class Runner:
>>> sampler=dict(type='DefaultSampler', shuffle=False),
>>> batch_size=1,
>>> num_workers=0),
>>> auto_scale_lr_cfg=dict(base_batch_size=16, enable=False),
>>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict(
>>> type='SGD', lr=0.01)),
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
@@ -226,6 +230,7 @@ def __init__(
train_cfg: Optional[Dict] = None,
val_cfg: Optional[Dict] = None,
test_cfg: Optional[Dict] = None,
auto_scale_lr_cfg: Optional[Dict] = None,
optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
@@ -273,6 +278,8 @@ def __init__(
self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
self.optim_wrapper = optim_wrapper

self.auto_scale_lr_cfg = auto_scale_lr_cfg

# If there is no need to adjust learning rate, momentum or other
# parameters of optimizer, param_scheduler can be None
if param_scheduler is not None and self.optim_wrapper is None:
@@ -411,6 +418,7 @@ def from_cfg(cls, cfg: ConfigType) -> 'Runner':
train_cfg=cfg.get('train_cfg'),
val_cfg=cfg.get('val_cfg'),
test_cfg=cfg.get('test_cfg'),
auto_scale_lr_cfg=cfg.get('auto_scale_lr_cfg'),
optim_wrapper=cfg.get('optim_wrapper'),
param_scheduler=cfg.get('param_scheduler'),
val_evaluator=cfg.get('val_evaluator'),
@@ -814,6 +822,66 @@ def wrap_model(self, model_wrapper_cfg: Optional[Dict],

return model

def scale_lr(self,
optim_wrapper: OptimWrapper,
auto_scale_lr_cfg: Optional[Dict] = None) -> None:
"""Automatically scaling learning rate in training according to the
ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch
size.
It scales the learning rate linearly according to the
`paper <https://arxiv.org/abs/1706.02677>`_.
Note:
``scale_lr`` must be called after building optimizer wrappers
and before building parameter schedulers.
Args:
optim_wrapper (OptimWrapper): An OptimWrapper object whose
parameter groups' learning rates need to be scaled.
auto_scale_lr_cfg (Dict, Optional): Config to scale the learning
rate automatically. It includes ``base_batch_size`` and
``enable``. ``base_batch_size`` is the batch size that the
optimizer lr is based on. ``enable`` is the switch to turn on
and off the feature.
"""
if (auto_scale_lr_cfg is None
or not auto_scale_lr_cfg.get('enable', False)):
return None

assert 'base_batch_size' in auto_scale_lr_cfg, \
'Lack of `base_batch_size` in `auto_scale_lr_cfg`.'
dataloader: Union[DataLoader, Dict] = self._train_dataloader
bs = dataloader.batch_size if isinstance(
dataloader, DataLoader) else dataloader['batch_size']
real_bs = self.world_size * bs
base_bs = auto_scale_lr_cfg['base_batch_size']
ratio = float(real_bs) / float(base_bs)
self.logger.info(f'LR is set based on batch size of {base_bs} '
f'and the current batch size is {real_bs}. '
f'Scaling the original LR by {ratio}.')

def _is_built(schedulers):
if isinstance(schedulers, dict):
return False if 'type' in schedulers else any(
_is_built(s) for s in schedulers.values())
if isinstance(schedulers, list):
return any(_is_built(s) for s in schedulers)
return isinstance(schedulers, _ParamScheduler)

if _is_built(self.param_schedulers):
raise RuntimeError('`scale_lr` should be called before building '
'ParamScheduler because ParamScheduler will '
'store initial lr from optimizer wrappers')

assert isinstance(optim_wrapper, OptimWrapper), \
'`scale_lr` should be called after building OptimWrapper'
wrappers = list(optim_wrapper.values()) if isinstance(
optim_wrapper, OptimWrapperDict) else [optim_wrapper]
for wrapper in wrappers:
for group in wrapper.optimizer.param_groups:
group['lr'] = group['lr'] * ratio
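As a worked example of the scaling arithmetic above, here is a PyTorch-only sketch (the world size and batch sizes are made-up values; MMEngine itself derives them from the distributed environment and the train dataloader):

    import torch
    from torch.optim import SGD

    model = torch.nn.Linear(2, 2)
    optimizer = SGD(model.parameters(), lr=0.01)

    world_size = 8          # assumed number of GPUs
    per_gpu_batch_size = 4  # assumed train dataloader batch size
    base_batch_size = 16    # batch size the configured lr was tuned for

    # Linear scaling rule: lr grows with the effective batch size.
    ratio = (world_size * per_gpu_batch_size) / base_batch_size
    for group in optimizer.param_groups:
        group['lr'] *= ratio

    print(optimizer.param_groups[0]['lr'])  # 0.01 * 2.0 -> 0.02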

def build_optim_wrapper(
self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict]
) -> Union[OptimWrapper, OptimWrapperDict]:
@@ -1439,6 +1507,9 @@ def train(self) -> None:
# because the latter depends on the former
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)

# Automatically scale the lr by the linear scaling rule
self.scale_lr(self.optim_wrapper, self.auto_scale_lr_cfg)

if self.param_schedulers:
self.param_schedulers = self.build_param_scheduler( # type: ignore
self.param_schedulers) # type: ignore
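The ordering here matters: parameter schedulers record the optimizer's lr when they are built. A small PyTorch-only sketch (not MMEngine code) of why scaling must happen before the schedulers are constructed:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import MultiStepLR

    model = torch.nn.Linear(2, 2)
    optimizer = SGD(model.parameters(), lr=0.01)

    # Scale before constructing the scheduler: MultiStepLR records
    # base_lrs from the optimizer at construction time, so scaling
    # afterwards would not be reflected in the schedule.
    for group in optimizer.param_groups:
        group['lr'] *= 2.0
    scheduler = MultiStepLR(optimizer, milestones=[1, 2])

    print(scheduler.base_lrs)  # [0.02]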
@@ -1717,6 +1788,13 @@ def resume(self,
self.logger.info(
'Number of GPUs used for the current experiment is not '
'consistent with the checkpoint being resumed from')
if (self.auto_scale_lr_cfg is None
or not self.auto_scale_lr_cfg.get('enable', False)):
raise RuntimeError(
'Cannot automatically rescale lr when resuming. Please '
'make sure the number of GPUs is consistent with the '
'previous training state stored in the checkpoint, or set '
'`enable` in `auto_scale_lr_cfg` to True.')

# resume meta information meta
self.meta = checkpoint['meta']
38 changes: 38 additions & 0 deletions tests/test_runner/test_runner.py
@@ -276,6 +276,7 @@ def setUp(self):
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
auto_scale_lr_cfg=dict(base_batch_size=16, enable=False),
optim_wrapper=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
@@ -651,6 +652,43 @@ def test_wrap_model(self):
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, CustomModelWrapper)

def test_scale_lr(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_scale_lr'
runner = Runner.from_cfg(cfg)

# When there is no base_batch_size in auto_scale_lr_cfg, an
# AssertionError will be raised.
auto_scale_lr_cfg = dict(enable=True)
optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01))
with self.assertRaises(AssertionError):
runner.scale_lr(optim_wrapper, auto_scale_lr_cfg)

# When auto_scale_lr_cfg is None or enable is False, the lr will
# not be linearly scaled.
auto_scale_lr_cfg = dict(base_batch_size=16, enable=False)
optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01))
runner.scale_lr(optim_wrapper)
self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'], 0.01)
runner.scale_lr(optim_wrapper, auto_scale_lr_cfg)
self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'], 0.01)

# When auto_scale_lr_cfg is correct and enable is True, the lr will
# be linearly scaled.
auto_scale_lr_cfg = dict(base_batch_size=16, enable=True)
real_bs = runner.world_size * cfg.train_dataloader['batch_size']
optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01))
runner.scale_lr(optim_wrapper, auto_scale_lr_cfg)
self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'],
0.01 * (real_bs / 16))

# Test when optim_wrapper is an OptimWrapperDict
optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01))
wrapper_dict = OptimWrapperDict(wrapper=optim_wrapper)
runner.scale_lr(wrapper_dict, auto_scale_lr_cfg)
scaled_lr = wrapper_dict['wrapper'].optimizer.param_groups[0]['lr']
self.assertEqual(scaled_lr, 0.01 * (real_bs / 16))

def test_build_optim_wrapper(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_optim_wrapper'
