-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathscheduler.py
97 lines (84 loc) · 4.43 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier. Must be >= 1;
            with 1 the rate simply stays at the base lr during warm-up.
        total_epoch: epoch at which the target learning rate is reached. The
            rate grows linearly from the base lr until then. Note that a value
            of 0 makes the warm-up formula divide by zero.
        after_scheduler: scheduler that takes over once ``total_epoch`` has
            passed (e.g. ReduceLROnPlateau). Optional.

    Raises:
        ValueError: if ``multiplier`` is smaller than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Becomes True the first time get_lr() hands over to after_scheduler.
        self.finished = False
        # The base-class constructor performs an initial step(), so every
        # attribute read by step()/get_lr() has to be assigned before it runs.
        super().__init__(optimizer)

    def _warmup_lrs(self):
        """Return the linearly interpolated warm-up rate of every param group."""
        scale = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
        return [base_lr * scale for base_lr in self.base_lrs]

    def _sync_last_lr(self):
        """Record the optimizer's current rates so get_last_lr() stays accurate."""
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # The follow-up scheduler starts from the warmed-up rate.
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs
                    ]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        return self._warmup_lrs()

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Advance one epoch when the follow-up scheduler is ReduceLROnPlateau.

        Args:
            metrics: value forwarded to ``after_scheduler.step`` after warm-up.
            epoch: explicit epoch index; defaults to ``last_epoch + 1``.
        """
        # Remember whether the caller chose the epoch before it is overwritten.
        epoch_given = epoch is not None
        if not epoch_given:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of an epoch, whereas the
        # others are called at the beginning, hence epoch 0 is counted as 1.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lrs()):
                param_group['lr'] = lr
        elif epoch_given:
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        else:
            # No explicit epoch: let the plateau scheduler keep its own count.
            self.after_scheduler.step(metrics)
        self._sync_last_lr()

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch.

        Args:
            epoch: explicit epoch index; ``None`` advances by one.
            metrics: monitored value, only used when ``after_scheduler`` is a
                ReduceLROnPlateau.
        """
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
            return None
        if self.finished and self.after_scheduler:
            # Warm-up is over: the follow-up scheduler drives the optimizer.
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            self._sync_last_lr()
            return None
        return super().step(epoch)
class PolynomialLRDecay(_LRScheduler):
    """Polynomially decay (decrease) the learning rate until max_decay_steps.

    The rate of each param group at step ``t`` is
    ``(base_lr - end_learning_rate) * (1 - t / max_decay_steps) ** power
    + end_learning_rate``; past ``max_decay_steps`` it is ``end_learning_rate``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: after this step, we stop decreasing the learning rate.
        end_learning_rate: learning rate held once the decay has finished.
        power: exponent of the polynomial; 1.0 gives a linear decay.

    Raises:
        ValueError: if ``max_decay_steps`` is not greater than 1.
    """

    def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0):
        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')
        self.max_decay_steps = max_decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        # Must exist before the base-class constructor runs, because that
        # constructor performs an initial step() which already moves the
        # schedule to step 1.
        self.last_step = 0
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate of every param group at ``last_step``."""
        if self.last_step > self.max_decay_steps:
            return [self.end_learning_rate for _ in self.base_lrs]
        decay = (1 - self.last_step / self.max_decay_steps) ** self.power
        return [
            (base_lr - self.end_learning_rate) * decay + self.end_learning_rate
            for base_lr in self.base_lrs
        ]

    def step(self, step=None):
        """Advance the schedule and write the new rates into the optimizer.

        Args:
            step: explicit step index; ``None`` advances by one. A value of 0
                is counted as step 1.
        """
        if step is None:
            step = self.last_step + 1
        self.last_step = step if step != 0 else 1
        # get_lr() is the single source of the formula; it also yields
        # end_learning_rate once the decay window has been passed, so a jump
        # beyond max_decay_steps no longer leaves a stale rate in place.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # Keep get_last_lr() usable even though the base step() is bypassed.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]