-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
schedulers.py
executable file
·136 lines (110 loc) · 5.1 KB
/
schedulers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class LRScheduler(_LRScheduler):
    """Base class for the step-driven learning-rate schedulers in this file.

    Subclasses provide ``get_lr()``. ``step()`` updates ``last_epoch`` (used
    here as a training-step counter) and writes the values returned by
    ``get_lr()`` into the optimizer's param groups.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """Validate ``optimizer`` and initialise the underlying scheduler.

        Args:
            optimizer: The ``torch.optim.Optimizer`` whose param groups are
                updated on every ``step()``.
            last_epoch: Index of the last completed step; forwarded unchanged
                to ``_LRScheduler.__init__``.

        Raises:
            TypeError: If ``optimizer`` is not an ``Optimizer`` instance.
        """
        # Flag consulted by step(). This class only ever sets it to False;
        # when something else turns it on, step() takes the step counter from
        # the optimizer state instead of counting locally.
        self.mixed_training = False

        # Check that optimizer param is valid
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))

        super().__init__(optimizer, last_epoch)

    def step(self, epoch=None):
        """Advance the schedule and apply the new learning rates.

        Args:
            epoch: Optional explicit step index ('epoch' is used to be
                consistent with ``_LRScheduler``). When ``None`` the counter
                is incremented by one. Ignored if ``mixed_training`` is set.
        """
        if self.mixed_training:
            # The assumption is that the step will be constant across
            # parameters, so the first parameter of the first group is used.
            first_param = self.optimizer.param_groups[0]['params'][0]
            state_dict = self.optimizer.state[first_param]
            if 'step' in state_dict:
                self.last_epoch = state_dict['step'] + 1
            else:
                self.last_epoch = 1
        else:
            self.last_epoch = epoch if epoch is not None else self.last_epoch + 1

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        # Remember what was applied so the inherited get_last_lr() works
        # instead of failing on a missing attribute.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
class ConstantLR(LRScheduler):
    """Scheduler that keeps every param group at its base learning rate."""

    def get_lr(self):
        """Return the base learning rates unchanged (the same list object)."""
        return self.base_lrs
class CosineWarmUpScheduler(LRScheduler):
    """Linear warm-up followed by a half-cosine decay of the learning rate.

    ``progress`` is ``last_epoch / total_steps``. While ``progress < warmup``
    each lr is ``base_lr * progress / warmup``; afterwards it is
    ``base_lr * 0.5 * (1 + cos(pi * progress))``, which reaches zero at
    ``progress == 1``.
    """

    def __init__(self, optimizer, warmup, total_steps, last_epoch=-1):
        """Store the schedule parameters and initialise the base scheduler.

        Args:
            optimizer: Optimizer whose learning rates are scheduled.
            warmup: Fraction of ``total_steps`` spent warming up; it is
                compared directly against ``last_epoch / total_steps``.
            total_steps: Number of steps the schedule spans.
            last_epoch: Index of the last completed step.
        """
        self.warmup = warmup
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per entry of ``base_lrs``."""
        progress = self.last_epoch / self.total_steps
        if progress < self.warmup:
            return [base_lr * progress / self.warmup for base_lr in self.base_lrs]

        # math.cos is used because progress is a plain Python number here
        # (torch.cos only accepts tensors). Progress is clamped at 1.0 so the
        # rate stays at zero past total_steps instead of climbing back up the
        # cosine.
        scale = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
        return [base_lr * scale for base_lr in self.base_lrs]
class ConstantWarmUpScheduler(LRScheduler):
    """Linear warm-up, then a constant learning rate.

    With ``progress = last_epoch / total_steps``, each rate is
    ``base_lr * progress / warmup`` while ``progress < warmup`` and the
    unmodified base rate afterwards.
    """

    def __init__(self, optimizer, warmup, total_steps, last_epoch=-1):
        """Record ``warmup`` and ``total_steps``, then set up the base class."""
        self.total_steps = total_steps
        self.warmup = warmup
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate for every base rate."""
        progress = self.last_epoch / self.total_steps
        still_warming = progress < self.warmup
        if not still_warming:
            # Warm-up finished: hand back the base rates as they are.
            return self.base_lrs
        return [rate * progress / self.warmup for rate in self.base_lrs]
class LinearWarmUpScheduler(LRScheduler):
    """Linear warm-up followed by a linear decay to zero.

    With ``progress = last_epoch / total_steps``, each rate is
    ``base_lr * progress / warmup`` while ``progress < warmup``. Afterwards it
    is ``base_lr * (progress - 1) / (warmup - 1)``, floored at zero, which
    falls from ``base_lr`` at ``progress == warmup`` to zero at
    ``progress == 1``.
    """

    def __init__(self, optimizer, warmup, total_steps, last_epoch=-1):
        """Store the schedule parameters and initialise the base scheduler.

        Args:
            optimizer: Optimizer whose learning rates are scheduled.
            warmup: Fraction of ``total_steps`` spent warming up; it is
                compared directly against ``last_epoch / total_steps``.
            total_steps: Number of steps the schedule spans.
            last_epoch: Index of the last completed step.
        """
        self.warmup = warmup
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per entry of ``base_lrs``."""
        progress = self.last_epoch / self.total_steps
        if progress < self.warmup:
            return [base_lr * progress / self.warmup for base_lr in self.base_lrs]

        decay_span = self.warmup - 1.0
        if decay_span == 0:
            # warmup == 1.0 leaves no decay phase; reaching this branch means
            # the schedule is already over, so use its end value rather than
            # dividing by zero.
            scale = 0.
        else:
            scale = max((progress - 1.0) / decay_span, 0.)
        return [base_lr * scale for base_lr in self.base_lrs]
class PolyWarmUpScheduler(LRScheduler):
    """Linear warm-up followed by polynomial decay, computed with tensors.

    All schedule parameters are held as tensors on ``device`` and the learning
    rate written into each param group is itself a tensor. Every param group
    receives the same value, derived from the scalar ``base_lr`` argument
    rather than from the per-group ``base_lrs``.
    """

    def __init__(self, optimizer, warmup, total_steps, degree=0.5, last_epoch=-1, base_lr=1., device='cpu'):
        """Convert the schedule parameters to tensors and set up the base class.

        Args:
            optimizer: Optimizer whose learning rates are scheduled.
            warmup: Fraction of ``total_steps`` spent warming up; it is
                compared directly against ``last_epoch / total_steps``.
            total_steps: Number of steps the schedule spans.
            degree: Exponent of the polynomial decay.
            last_epoch: Index of the last completed step.
            base_lr: Peak learning rate shared by all param groups.
            device: Device on which the schedule tensors are created.
        """
        self.warmup = torch.tensor(warmup, device=device)
        self.total_steps = torch.tensor(total_steps, device=device)
        self.degree = torch.tensor(degree, device=device)
        device_last_epoch = torch.tensor(last_epoch, device=device)
        self.base_lr = torch.tensor(base_lr, device=device)
        self.device = device
        super().__init__(optimizer, device_last_epoch)

    def step(self, epoch=None):
        """Refresh the step counter from the optimizer and apply the new lr.

        Args:
            epoch: Unused; accepted so the signature matches the base class.
        """
        # The counter is read from the first param group's 'step' entry when
        # present (presumably maintained by the optimizer itself; confirm
        # against the optimizer in use). Otherwise this counts as step 1.
        param_group = self.optimizer.param_groups[0]
        if 'step' in param_group:
            self.last_epoch = param_group['step'] + 1
        else:
            self.last_epoch = torch.tensor(1., device=self.device)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        # Remember what was applied so the inherited get_last_lr() works.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr(self):
        """Return the current lr tensor, repeated once per param group."""
        progress = self.last_epoch / self.total_steps
        # Floor the decay base at zero: past total_steps (1 - progress) is
        # negative, and a negative base with a fractional degree gives NaN.
        remaining = torch.clamp(1.0 - progress, min=0.0)
        lr_tensor = torch.where(
            progress < self.warmup,
            self.base_lr * progress / self.warmup,
            self.base_lr * (remaining ** self.degree),
        )
        return [lr_tensor] * len(self.optimizer.param_groups)