sgdr.py
import numpy as np


def adjust_lr(optimizer, lr):
    """Sets the learning rate of every parameter group in the optimizer."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
class SGDRScheduler:
    """
    Implements STOCHASTIC GRADIENT DESCENT WITH WARM RESTARTS (SGDR)
    with cosine annealing from https://arxiv.org/pdf/1608.03983.pdf.
    """

    def __init__(self, optimizer, max_lr, cycle_length, min_lr=1e-5, warmup_steps=10, current_step=0):
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr = optimizer.param_groups[0]['lr']
        self.cycle_length = cycle_length
        self.current_step = current_step
        self.warmup_steps = warmup_steps

    def calculate_lr(self):
        """Calculates the new learning rate with cosine annealing."""
        step = self.current_step % self.cycle_length  # position within the current cycle
        self.lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * \
            (1 + np.cos((step / self.cycle_length) * np.pi))

    def step(self):
        """Advances the schedule by one step and updates the optimizer's learning rate."""
        self.current_step += 1
        self.calculate_lr()
        if self.current_step < self.warmup_steps:
            self.lr /= 10.0  # take a few steps with a lower lr to "warmup"
        adjust_lr(self.optimizer, self.lr)
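
# A minimal usage sketch for SGDRScheduler, kept as a comment. It assumes a PyTorch
# optimizer; `model`, `criterion` and `train_loader` are hypothetical placeholders
# that are not defined in this module:
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = SGDRScheduler(optimizer, max_lr=0.1, cycle_length=len(train_loader))
#     for inputs, targets in train_loader:
#         loss = criterion(model(inputs), targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()  # cosine-anneal the lr; warm restart every cycle_length steps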
class LRFinderScheduler:
    """
    Implements the learning rate finding schedule from
    STOCHASTIC GRADIENT DESCENT WITH WARM RESTARTS (SGDR)
    https://arxiv.org/pdf/1608.03983.pdf.
    Increases the learning rate every step following a power law
    (lr = min_lr * step ** gamma).
    Plot the loss against the learning rate and choose the rate at which
    the loss decreases most quickly.
    """

    def __init__(self, optimizer, min_lr=1e-6, gamma=2.5, current_step=0):
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.gamma = gamma
        self.lr = optimizer.param_groups[0]['lr']
        self.current_step = current_step

    def calculate_lr(self):
        """Calculates the new learning rate as a power of the current step."""
        self.lr = self.min_lr * (self.current_step ** self.gamma)

    def step(self):
        """Advances the schedule by one step and updates the optimizer's learning rate."""
        self.current_step += 1
        self.calculate_lr()
        adjust_lr(self.optimizer, self.lr)
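

if __name__ == "__main__":
    # Minimal self-check sketch: drives both schedulers with a stand-in
    # optimizer-like object (only `param_groups` is needed by adjust_lr) and
    # prints a few points of the resulting learning-rate curves. The stand-in
    # class below is illustrative only and not part of the schedulers' API.
    class _DummyOptimizer:
        def __init__(self, lr):
            self.param_groups = [{'lr': lr}]

    sgdr = SGDRScheduler(_DummyOptimizer(0.1), max_lr=0.1, cycle_length=50)
    sgdr_lrs = []
    for _ in range(120):  # a little more than two cosine cycles
        sgdr.step()
        sgdr_lrs.append(sgdr.lr)

    finder = LRFinderScheduler(_DummyOptimizer(1e-6))
    finder_lrs = []
    for _ in range(100):
        finder.step()
        finder_lrs.append(finder.lr)

    print("SGDR lr at steps 1, 25, 50, 51:",
          [round(sgdr_lrs[i], 6) for i in (0, 24, 49, 50)])
    print("LR finder lr at steps 1, 10, 100:",
          [round(finder_lrs[i], 8) for i in (0, 9, 99)])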