Commit f0c43fd

[Feature] Add OneCycleLrUpdaterHook (#906)
* [Feature] Add OneCycleLrUpdaterHook
* fix docstring
* fix docstring
* Remove redundant code
1 parent 3ae1b25 commit f0c43fd

File tree

3 files changed: +412 -4 lines changed

mmcv/runner/hooks/lr_updater.py

+147
@@ -1,4 +1,5 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import numbers
 from math import cos, pi

 from .hook import HOOKS, Hook
@@ -398,6 +399,124 @@ def get_lr(self, runner, base_lr):
                              progress / (end_iter - start_iter))


+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """One Cycle LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr / div_factor.
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr / final_div_factor.
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+    """
+
+    def __init__(self,
+                 max_lr,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('the type of max_lr must be number, list or '
+                             f'dict, but got {type(max_lr)}')
+        self._max_lr = max_lr
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected a float between 0 and 1 for '
+                             f'pct_start, but got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # init lr_phases
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                _max_lr = format_param(k, optim, self._max_lr)
+                self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+                for group, lr in zip(optim.param_groups, self.base_lr[k]):
+                    group.setdefault('initial_lr', lr)
+        else:
+            k = type(runner.optimizer).__name__
+            _max_lr = format_param(k, runner.optimizer, self._max_lr)
+            self.base_lr = [lr / self.div_factor for lr in _max_lr]
+            for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+                group.setdefault('initial_lr', lr)
+
+        if self.three_phase:
+            self.lr_phases.append([
+                float(self.pct_start * runner.max_iters) - 1, 1,
+                self.div_factor
+            ])
+            self.lr_phases.append([
+                float(2 * self.pct_start * runner.max_iters) - 2,
+                self.div_factor, 1
+            ])
+            self.lr_phases.append(
+                [runner.max_iters - 1, 1, 1 / self.final_div_factor])
+        else:
+            self.lr_phases.append([
+                float(self.pct_start * runner.max_iters) - 1, 1,
+                self.div_factor
+            ])
+            self.lr_phases.append([
+                runner.max_iters - 1, self.div_factor,
+                1 / self.final_div_factor
+            ])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            if curr_iter <= end_iter:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+
+
 def annealing_cos(start, end, factor, weight=1):
     """Calculate annealing cos learning rate.

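For context, a minimal usage sketch of the new hook follows. The hook class and its arguments are exactly those added above; the import path points to where the class is defined, while `runner` and all numeric values are illustrative assumptions, not part of this commit:

# Sketch: attach the new LR hook to an existing mmcv runner.
from mmcv.runner.hooks.lr_updater import OneCycleLrUpdaterHook

lr_hook = OneCycleLrUpdaterHook(
    max_lr=0.01,           # peak LR; initial_lr = max_lr / div_factor = 4e-4
    pct_start=0.3,         # 30% of max_iters spent ramping the LR up
    anneal_strategy='cos',
    div_factor=25,
    final_div_factor=1e4)  # min_lr = initial_lr / final_div_factor = 4e-8
runner.register_hook(lr_hook)  # `runner`: an assumed mmcv runner instance

Note that `by_epoch` is forced to False, so the schedule advances per iteration over `runner.max_iters`.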
@@ -414,3 +533,31 @@ def annealing_cos(start, end, factor, weight=1):
     """
     cos_out = cos(pi * factor) + 1
     return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+    """Calculate annealing linear learning rate.
+
+    Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): The current annealing percentage. Range from 0.0 to
+            1.0.
+    """
+    return start + (end - start) * factor
+
+
+def format_param(name, optim, param):
+    if isinstance(param, numbers.Number):
+        return [param] * len(optim.param_groups)
+    elif isinstance(param, (list, tuple)):  # multi param groups
+        if len(param) != len(optim.param_groups):
+            raise ValueError(f'expected {len(optim.param_groups)} '
+                             f'values for {name}, got {len(param)}')
+        return param
+    else:  # multi optimizers
+        if name not in param:
+            raise KeyError(f'{name} is not found in {param.keys()}')
+        return param[name]
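As a quick sanity check on the two annealing helpers, here are a few hand-computed values in self-contained Python (the function bodies are copied verbatim from the diff above):

from math import cos, pi

def annealing_cos(start, end, factor, weight=1):
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

def annealing_linear(start, end, factor):
    return start + (end - start) * factor

# Both anneal from `start` at factor=0.0 to `end` at factor=1.0; the cosine
# variant is flat near the endpoints and steepest at factor=0.5.
assert annealing_cos(1.0, 0.0, 0.0) == 1.0
assert annealing_cos(1.0, 0.0, 1.0) == 0.0
assert abs(annealing_cos(1.0, 0.0, 0.5) - 0.5) < 1e-12
assert annealing_linear(1.0, 0.0, 0.25) == 0.75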

mmcv/runner/hooks/momentum_updater.py

+197 -2
@@ -1,5 +1,5 @@
 from .hook import HOOKS, Hook
-from .lr_updater import annealing_cos
+from .lr_updater import annealing_cos, annealing_linear, format_param


 class MomentumUpdaterHook(Hook):
@@ -130,7 +130,7 @@ def get_momentum(self, runner, base_momentum):
 class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
     """Cyclic momentum Scheduler.

-    Implemet the cyclical momentum scheduler policy described in
+    Implement the cyclical momentum scheduler policy described in
     https://arxiv.org/pdf/1708.07120.pdf

     This momentum scheduler usually used together with the CyclicLRUpdater
@@ -197,3 +197,198 @@ def get_momentum(self, runner, base_momentum):
         return annealing_cos(base_momentum * start_ratio,
                              base_momentum * end_ratio,
                              progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+    """OneCycle momentum Scheduler.
+
+    This momentum scheduler is usually used together with the
+    OneCycleLrUpdater to improve performance.
+
+    Args:
+        base_momentum (float or list): Lower momentum boundaries in the cycle
+            for each parameter group. Note that momentum is cycled inversely
+            to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally, it defines the cycle
+            amplitude (max_momentum - base_momentum). Note that momentum is
+            cycled inversely to learning rate; at the start of a cycle,
+            momentum is 'max_momentum' and learning rate is 'base_lr'.
+            Default: 0.95
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule in which the momentum stays at 'max_momentum' (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be of type float, list '
+                             'or dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be of type float, list '
+                             'or dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected a float between 0 and 1 for '
+                             f'pct_start, but got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # init momentum_phases
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                if ('momentum' not in optim.defaults
+                        and 'betas' not in optim.defaults):
+                    raise ValueError('optimizer must support momentum with '
+                                     'the option enabled')
+                self.use_beta1 = 'betas' in optim.defaults
+                _base_momentum = format_param(k, optim, self._base_momentum)
+                _max_momentum = format_param(k, optim, self._max_momentum)
+                for group, b_momentum, m_momentum in zip(
+                        optim.param_groups, _base_momentum, _max_momentum):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+                    group['max_momentum'] = m_momentum
+        else:
+            optim = runner.optimizer
+            if ('momentum' not in optim.defaults
+                    and 'betas' not in optim.defaults):
+                raise ValueError('optimizer must support momentum with '
+                                 'the option enabled')
+            self.use_beta1 = 'betas' in optim.defaults
+            k = type(optim).__name__
+            _base_momentum = format_param(k, optim, self._base_momentum)
+            _max_momentum = format_param(k, optim, self._max_momentum)
+            for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                     _base_momentum,
+                                                     _max_momentum):
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (m_momentum, beta2)
+                else:
+                    group['momentum'] = m_momentum
+                group['base_momentum'] = b_momentum
+                group['max_momentum'] = m_momentum
+
+        if self.three_phase:
+            self.momentum_phases.append({
+                'end_iter': float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': float(2 * self.pct_start * runner.max_iters) - 2,
+                'start_momentum': 'base_momentum',
+                'end_momentum': 'max_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'max_momentum'
+            })
+        else:
+            self.momentum_phases.append({
+                'end_iter': float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'base_momentum',
+                'end_momentum': 'max_momentum'
+            })
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, param_group):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k, optim in runner.optimizer.items():
+                momentum_groups[k] = []  # one momentum list per optimizer
+                for param_group in optim.param_groups:
+                    momentum_groups[k].append(
+                        self.get_momentum(runner, param_group))
+            return momentum_groups
+        else:
+            momentum_groups = []
+            for param_group in runner.optimizer.param_groups:
+                momentum_groups.append(self.get_momentum(runner, param_group))
+            return momentum_groups
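Since the docstring notes that this scheduler is usually paired with the OneCycle LR updater, a combined sketch follows. The two hook classes are the ones added in this commit; `runner` and the numeric values are assumptions for illustration:

from mmcv.runner.hooks.lr_updater import OneCycleLrUpdaterHook
from mmcv.runner.hooks.momentum_updater import OneCycleMomentumUpdaterHook

# Momentum is cycled inversely to the LR: while the LR ramps up towards
# max_lr, momentum anneals from max_momentum down to base_momentum, then
# climbs back to max_momentum as the LR decays.
runner.register_hook(OneCycleLrUpdaterHook(max_lr=0.01))
runner.register_hook(
    OneCycleMomentumUpdaterHook(base_momentum=0.85, max_momentum=0.95))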
