forked from mila-iqia/covid_p2p_risk_prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
opts.py
65 lines (54 loc) · 1.91 KB
/
opts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import math
class WarmCosineMixin(object):
    """Mixin adding a linear-warmup + cosine-annealing learning-rate schedule.

    Intended to be listed *before* a ``torch.optim.Optimizer`` subclass in the
    bases of a concrete class, so that ``super()`` calls reach the real
    optimizer. On every ``step()`` call the internal step counter is advanced
    and the ``"lr"`` entry of every parameter group is overwritten with
    ``rate()``.

    The schedule, as a function of the step counter ``t``:

    * ``t == 0``: ``0.0``
    * ``0 < t <= num_steps``: the smaller of the linear ramp
      ``(t / num_warmup_steps) * eta_max`` and the cosine curve
      ``eta_min + (eta_max - eta_min) * (1 + cos(pi * t / num_steps)) / 2``
    * ``t > num_steps``: ``eta_min``
    """

    def __init__(self, *args, num_warmup_steps, num_steps, eta_min, eta_max, **kwargs):
        """Initialise the wrapped optimizer and the schedule state.

        Args:
            *args: Positional arguments forwarded to the next ``__init__`` in
                the MRO.
            num_warmup_steps: Step count used as the denominator of the linear
                ramp. A value ``<= 0`` disables the warmup ramp.
            num_steps: Step count over which the cosine curve goes from
                ``eta_max`` to ``eta_min``.
            eta_min: Learning rate returned once ``num_steps`` is exceeded.
            eta_max: Learning rate at the top of the cosine curve.
            **kwargs: Keyword arguments forwarded to the next ``__init__``.
                ``lr`` defaults to ``0.0`` when absent; whatever value is
                passed is overwritten by the schedule right after
                construction.
        """
        kwargs.setdefault("lr", 0.0)
        super().__init__(*args, **kwargs)
        self.num_warmup_steps = num_warmup_steps
        self.num_steps = num_steps
        self.eta_min = eta_min
        self.eta_max = eta_max
        # Privates
        self._step = 0
        # Make the parameter groups reflect the schedule before the first step.
        self._set_lr()

    def _set_lr(self):
        """Write the current scheduled rate into every parameter group."""
        rate = self.rate()
        # noinspection PyUnresolvedReferences
        for group in self.param_groups:
            group["lr"] = rate

    @torch.no_grad()
    def step(self, closure=None):
        """Advance the schedule by one step, then run the wrapped optimizer step.

        Args:
            closure: Optional callable forwarded unchanged to the wrapped
                optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        # The counter is bumped first, so the first update uses rate(1),
        # not the zero rate of step 0.
        self._step += 1
        self._set_lr()
        # Propagate the wrapped optimizer's return value to the caller.
        # noinspection PyUnresolvedReferences
        return super().step(closure=closure)

    def rate(self, step=None):
        """Return a learning rate at a step (given or taken from internal attribute).

        Args:
            step: Step index to evaluate the schedule at. ``None`` means the
                internal step counter.

        Returns:
            The scheduled learning rate for ``step``.
        """
        if step is None:
            step = self._step
        if step == 0:
            return 0.0
        if step > self.num_steps:
            return self.eta_min
        cos_lr = (
            self.eta_min
            + (self.eta_max - self.eta_min)
            * (1 + math.cos(math.pi * step / self.num_steps))
            / 2
        )
        # A non-positive warmup length means "no warmup": skip the ramp
        # instead of dividing by zero or producing a negative rate.
        if self.num_warmup_steps <= 0:
            return cos_lr
        lin_lr = (step / self.num_warmup_steps) * self.eta_max
        # The ramp is below the cosine curve early on and above it later,
        # so taking the minimum stitches the two phases together.
        return min(lin_lr, cos_lr)
class WarmCosineAdam(WarmCosineMixin, torch.optim.Adam):
    """``torch.optim.Adam`` combined with ``WarmCosineMixin``; defines no members of its own."""
# noinspection PyUnresolvedReferences
class WarmCosineRMSprop(WarmCosineMixin, torch.optim.RMSprop):
    """``torch.optim.RMSprop`` combined with ``WarmCosineMixin``; defines no members of its own."""
def __getattr__(name):
    """Resolve ``name`` to an optimizer, falling back to ``torch.optim``.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails, so ``module.Adam`` ends up here and is
    answered from ``torch.optim``. The function can also be called directly,
    in which case names defined in this module are looked up first.

    Args:
        name: Attribute name being requested.

    Returns:
        The object of that name from this module's globals when it is a
        ``torch.optim.Optimizer`` subclass, otherwise the attribute of that
        name on ``torch.optim``.

    Raises:
        AttributeError: If ``name`` is a dunder name, if it names a global of
            this module that is not an optimizer class, or if ``torch.optim``
            has no such attribute.
    """
    # Dunder probes (``__path__``, ``__all__``, ...) come from the import
    # machinery and introspection; answering them with torch.optim's values
    # would make this module impersonate that package.
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    obj_in_globals = globals().get(name)
    if obj_in_globals is not None:
        # Explicit check instead of ``assert``: it must survive ``python -O``,
        # and ``issubclass`` alone raises TypeError for non-class globals
        # such as imported modules.
        is_optimizer = isinstance(obj_in_globals, type) and issubclass(
            obj_in_globals, torch.optim.Optimizer
        )
        if not is_optimizer:
            raise AttributeError(
                f"{name!r} in module {__name__!r} is not a torch.optim.Optimizer subclass"
            )
        return obj_in_globals
    # Object not found in globals, look for optim
    return getattr(torch.optim, name)