From a5c6199781fa44de2e4ab4b06a4d96412225de2c Mon Sep 17 00:00:00 2001
From: zhouzaida
Date: Wed, 24 Mar 2021 19:40:57 +0800
Subject: [PATCH 1/4] [Feature] Add OneCycleLrUpdaterHook

---
 mmcv/runner/hooks/lr_updater.py       | 147 +++++++++++++++++++
 mmcv/runner/hooks/momentum_updater.py | 199 +++++++++++++++++++++++++-
 tests/test_runner/test_hooks.py       |  70 ++++++++-
 3 files changed, 412 insertions(+), 4 deletions(-)

diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py
index 0120b58d48..266c366dce 100644
--- a/mmcv/runner/hooks/lr_updater.py
+++ b/mmcv/runner/hooks/lr_updater.py
@@ -1,4 +1,5 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import numbers
 from math import cos, pi
 
 from .hook import HOOKS, Hook
@@ -398,6 +399,124 @@ def get_lr(self, runner, base_lr):
                              progress / (end_iter - start_iter))
 
 
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """Cyclic LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+    """
+
+    def __init__(self,
+                 max_lr,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('max_lr must be a number, list or dict, '
+                             f'but got {type(max_lr)}')
+        self._max_lr = max_lr
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected float between 0 and 1 for pct_start, '
+                             f'but got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # init lr_phases
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                _max_lr = format_param(k, optim, self._max_lr)
+                self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+                for group, lr in zip(optim.param_groups, self.base_lr[k]):
+                    group.setdefault('initial_lr', lr)
+        else:
+            k = type(runner.optimizer).__name__
+            _max_lr = format_param(k, runner.optimizer, self._max_lr)
+            self.base_lr = [lr / self.div_factor for lr in _max_lr]
+            for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+                group.setdefault('initial_lr', lr)
+
+        if self.three_phase:
+            self.lr_phases.append([
+                float(self.pct_start * runner.max_iters) - 1, 1,
+                self.div_factor
+            ])
+            self.lr_phases.append([
+                float(2 * self.pct_start * runner.max_iters) - 2,
+                self.div_factor, 1
+            ])
+            self.lr_phases.append(
+                [runner.max_iters - 1, 1, 1 / self.final_div_factor])
+        else:
+            self.lr_phases.append([
+                float(self.pct_start * runner.max_iters) - 1, 1,
+                self.div_factor
+            ])
+            self.lr_phases.append([
+                runner.max_iters - 1, self.div_factor,
+                1 / self.final_div_factor
+            ])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            if curr_iter <= end_iter or i == len(self.lr_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+
+
 def annealing_cos(start, end, factor, weight=1):
     """Calculate annealing cos learning rate.
 
@@ -414,3 +533,31 @@ def annealing_cos(start, end, factor, weight=1):
     """
     cos_out = cos(pi * factor) + 1
     return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+    """Calculate annealing linear learning rate.
+
+    Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): The current percentage of the annealing process.
+            Range from 0.0 to 1.0.
+ """ + return start + (end - start) * factor + + +def format_param(name, optim, param): + if isinstance(param, numbers.Number): + return [param] * len(optim.param_groups) + elif isinstance(param, (list, tuple)): # multi param groups + if len(param) != len(optim.param_groups): + raise ValueError(f'expected {len(optim.param_groups)} ' + f'values for {name}, got {len(param)}') + return param + else: # multi optimizers + if name not in param: + raise KeyError(f'{name} is not found in {param.keys()}') + return param[name] diff --git a/mmcv/runner/hooks/momentum_updater.py b/mmcv/runner/hooks/momentum_updater.py index b349071ee6..a9f70cfec0 100644 --- a/mmcv/runner/hooks/momentum_updater.py +++ b/mmcv/runner/hooks/momentum_updater.py @@ -1,5 +1,5 @@ from .hook import HOOKS, Hook -from .lr_updater import annealing_cos +from .lr_updater import annealing_cos, annealing_linear, format_param class MomentumUpdaterHook(Hook): @@ -130,7 +130,7 @@ def get_momentum(self, runner, base_momentum): class CyclicMomentumUpdaterHook(MomentumUpdaterHook): """Cyclic momentum Scheduler. - Implemet the cyclical momentum scheduler policy described in + Implement the cyclical momentum scheduler policy described in https://arxiv.org/pdf/1708.07120.pdf This momentum scheduler usually used together with the CyclicLRUpdater @@ -197,3 +197,198 @@ def get_momentum(self, runner, base_momentum): return annealing_cos(base_momentum * start_ratio, base_momentum * end_ratio, progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleMomentumUpdaterHook(MomentumUpdaterHook): + """OneCycle momentum Scheduler. + + This momentum scheduler usually used together with the OneCycleLrUpdater + to improve the performance. + + Args: + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is + 'max_momentum' and learning rate is 'base_lr' + Default: 0.95 + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). 
+            Default: False
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be one of float, '
+                             'list or dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be one of float, '
+                             'list or dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected float between 0 and 1 for pct_start, '
+                             f'but got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # init momentum_phases
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                if ('momentum' not in optim.defaults
+                        and 'betas' not in optim.defaults):
+                    raise ValueError('optimizer must support momentum with '
+                                     'option enabled')
+                self.use_beta1 = 'betas' in optim.defaults
+                _base_momentum = format_param(k, optim, self._base_momentum)
+                _max_momentum = format_param(k, optim, self._max_momentum)
+                for group, b_momentum, m_momentum in zip(
+                        optim.param_groups, _base_momentum, _max_momentum):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+                    group['max_momentum'] = m_momentum
+        else:
+            optim = runner.optimizer
+            if ('momentum' not in optim.defaults
+                    and 'betas' not in optim.defaults):
+                raise ValueError('optimizer must support momentum with '
+                                 'option enabled')
+            self.use_beta1 = 'betas' in optim.defaults
+            k = type(optim).__name__
+            _base_momentum = format_param(k, optim, self._base_momentum)
+            _max_momentum = format_param(k, optim, self._max_momentum)
+            for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                     _base_momentum,
+                                                     _max_momentum):
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (m_momentum, beta2)
+                else:
+                    group['momentum'] = m_momentum
+                group['base_momentum'] = b_momentum
+                group['max_momentum'] = m_momentum
+
+        if self.three_phase:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter':
+                float(2 * self.pct_start * runner.max_iters) - 2,
+                'start_momentum':
+                'base_momentum',
+                'end_momentum':
+                'max_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'max_momentum'
+            })
+        else:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'base_momentum',
+                'end_momentum': 'max_momentum'
+            })
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, param_group):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k, optim in runner.optimizer.items():
+                momentum_groups[k] = [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+            return momentum_groups
+        else:
+            momentum_groups = []
+            for param_group in runner.optimizer.param_groups:
+                momentum_groups.append(self.get_momentum(runner, param_group))
+            return momentum_groups
diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py
index 292ab483bc..f6dff699ea 100644
--- a/tests/test_runner/test_hooks.py
+++ b/tests/test_runner/test_hooks.py
@@ -1,7 +1,7 @@
 """Tests the hooks with runners.
 
 CommandLine:
-    pytest tests/test_hooks.py
+    pytest tests/test_runner/test_hooks.py
     xdoctest tests/test_hooks.py zero
 """
 import logging
@@ -21,7 +21,8 @@
 from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
                          MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
                          build_runner)
-from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
+from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
+                                          OneCycleLrUpdaterHook)
 
 
 def test_checkpoint_hook():
@@ -251,6 +252,71 @@ def test_cosine_runner_hook():
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
 
 
+def test_one_cycle_runner_hook():
+    """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
+    with pytest.raises(AssertionError):
+        # by_epoch should be True
+        OneCycleLrUpdaterHook(max_lr=0.1, by_epoch=True)
+
+    with pytest.raises(ValueError):
+        # expected float between 0 and 1
+        OneCycleLrUpdaterHook(max_lr=0.1, pct_start=-0.1)
+
+    with pytest.raises(ValueError):
+        # anneal_strategy should be either 'cos' or 'linear'
+        OneCycleLrUpdaterHook(max_lr=0.1, anneal_strategy='sin')
+
+    sys.modules['pavi'] = MagicMock()
+    loader = DataLoader(torch.ones((10, 2)))
+    runner = _build_demo_runner()
+
+    # add momentum scheduler
+    hook_cfg = dict(
+        type='OneCycleMomentumUpdaterHook',
+        base_momentum=0.85,
+        max_momentum=0.95,
+        pct_start=0.5,
+        anneal_strategy='cos',
+        three_phase=False)
+    runner.register_hook_from_cfg(hook_cfg)
+
+    # add LR scheduler
+    hook_cfg = dict(
+        type='OneCycleLrUpdaterHook',
+        max_lr=0.01,
+        pct_start=0.5,
+        anneal_strategy='cos',
+        div_factor=25,
+        final_div_factor=1e4,
+        three_phase=False)
+    runner.register_hook_from_cfg(hook_cfg)
+    runner.register_hook(IterTimerHook())
+    # add pavi hook
+    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
+    runner.register_hook(hook)
+    runner.run([loader], [('train', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    # TODO: use a more elegant way to check values
+    assert hasattr(hook, 'writer')
+    calls = [
+        call('train', {
+            'learning_rate': 0.0003999999999999993,
+            'momentum': 0.95
+        }, 1),
+        call('train', {
+            'learning_rate': 0.00904508879153485,
+            'momentum': 0.8595491502812526
+        }, 6),
+        call('train', {
+            'learning_rate': 4e-08,
+            'momentum': 0.95
+        }, 10)
+    ]
+    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
+
+
 def test_cosine_restart_lr_update_hook():
     """Test CosineRestartLrUpdaterHook."""
     with pytest.raises(AssertionError):

From a4d589abcc01a13c79afd4fb747723c044506ccd Mon Sep 17 00:00:00 2001
From: zhouzaida <58739961+zhouzaida@users.noreply.github.com>
Date: Thu, 25 Mar 2021 16:05:25 +0800
Subject: [PATCH 2/4] fix docstring

---
 mmcv/runner/hooks/lr_updater.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py
index 266c366dce..f4ed47e07f 100644
--- a/mmcv/runner/hooks/lr_updater.py
+++ b/mmcv/runner/hooks/lr_updater.py
@@ -401,7 +401,7 @@ def get_lr(self, runner, base_lr):
 
 @HOOKS.register_module()
 class OneCycleLrUpdaterHook(LrUpdaterHook):
-    """Cyclic LR Scheduler.
+    """One Cycle LR Scheduler.
 
     The 1cycle learning rate policy changes the learning rate after every
    batch.
The one cycle learning rate policy is described in From c56909f5668c4a487b6b12aacd665bed69f405a8 Mon Sep 17 00:00:00 2001 From: zhouzaida <58739961+zhouzaida@users.noreply.github.com> Date: Thu, 25 Mar 2021 19:06:45 +0800 Subject: [PATCH 3/4] fix docstring --- tests/test_runner/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index f6dff699ea..ff00fdbcc3 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -255,7 +255,7 @@ def test_cosine_runner_hook(): def test_one_cycle_runner_hook(): """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook.""" with pytest.raises(AssertionError): - # by_epoch should be True + # by_epoch should be False OneCycleLrUpdaterHook(max_lr=0.1, by_epoch=True) with pytest.raises(ValueError): From 602fe1c3075775d53d231d6677b42d5509ea0099 Mon Sep 17 00:00:00 2001 From: zhouzaida <58739961+zhouzaida@users.noreply.github.com> Date: Wed, 31 Mar 2021 11:01:02 +0800 Subject: [PATCH 4/4] Remove redundant code --- mmcv/runner/hooks/lr_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py index f4ed47e07f..d86f0dcd0f 100644 --- a/mmcv/runner/hooks/lr_updater.py +++ b/mmcv/runner/hooks/lr_updater.py @@ -508,7 +508,7 @@ def get_lr(self, runner, base_lr): curr_iter = runner.iter start_iter = 0 for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases): - if curr_iter <= end_iter or i == len(self.lr_phases) - 1: + if curr_iter <= end_iter: pct = (curr_iter - start_iter) / (end_iter - start_iter) lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr, pct)
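
Usage sketch (not part of the patch series): a minimal end-to-end example of how the two hooks added above are meant to be wired together, closely mirroring test_one_cycle_runner_hook. ToyModel, the SGD hyper-parameters, and the temporary work_dir are illustrative assumptions rather than anything this PR introduces. With max_lr=0.01 and div_factor=25, the schedule starts at 0.01 / 25 = 4e-4 and, with final_div_factor=1e4, ends at 4e-4 / 1e4 = 4e-8, matching the values the test asserts.

import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import build_runner


class ToyModel(nn.Module):
    """Minimal model exposing the train_step interface mmcv runners call."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x, optimizer, **kwargs):
        # EpochBasedRunner feeds each batch here and expects a dict back
        return dict(loss=self(x).mean())


model = ToyModel()
# momentum must be enabled on the optimizer, otherwise
# OneCycleMomentumUpdaterHook raises ValueError in before_run
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=1),
    default_args=dict(
        model=model,
        optimizer=optimizer,
        work_dir=tempfile.mkdtemp(),
        logger=logging.getLogger()))

# momentum is cycled inversely to the learning rate
runner.register_hook_from_cfg(
    dict(
        type='OneCycleMomentumUpdaterHook',
        base_momentum=0.85,
        max_momentum=0.95,
        pct_start=0.5))
# lr warms up from max_lr / div_factor to max_lr over the first pct_start
# of training, then anneals down to max_lr / (div_factor * final_div_factor)
runner.register_hook_from_cfg(
    dict(
        type='OneCycleLrUpdaterHook',
        max_lr=0.01,
        pct_start=0.5,
        anneal_strategy='cos',
        div_factor=25,
        final_div_factor=1e4))

# ten iterations over a toy dataset, as in the test above
runner.run([DataLoader(torch.ones((10, 2)))], [('train', 1)])

Note that no OptimizerHook is registered here, so the sketch exercises only the schedules (both hooks compute their phase boundaries from runner.max_iters in before_run), not actual weight updates.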