This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
lr_scheduler.py
525 lines (459 loc) · 17.9 KB
/
lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Code for LR Schedulers.
See ParlAILRScheduler (super class) and subclasses for detailed documentation
"""
import math
from typing import Optional
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
from abc import abstractmethod
from torch import optim
import numpy as np
from parlai.core.exceptions import StopTrainException
from parlai.utils.misc import warn_once
class ParlAILRScheduler(object):
    """
    Class for LR Schedulers.

    Includes some basic functionality by default - setting up the warmup
    scheduler, passing the correct number of steps to train_step, loading and
    saving states.

    Subclasses must implement abstract methods train_step() and valid_step().
    Schedulers should be initialized with lr_scheduler_factory().
    __init__() should not be called directly.

    NOTE: this class derives from ``object`` (not ``abc.ABC``), so the
    ``@abstractmethod`` markers below document the contract but are not
    enforced at instantiation time.
    """

    def __init__(self, hard_reset, warmup_updates, warmup_rate):
        """
        Initialize shared warmup settings.

        Specific main schedulers should be initialized in the subclasses. Do not
        invoke this method directly; use lr_scheduler_factory().

        :param bool hard_reset:
            If true, the LR scheduler should ignore the state dictionary.
        :param int warmup_updates:
            Number of training step updates warmup scheduler should take.
            Negative values are clamped to 0 (warmup disabled).
        :param float warmup_rate:
            Starting multiplier for warmup scheduler.
        """
        self._number_training_updates = 0
        self.warmup_updates = max(0, warmup_updates)
        self.warmup_rate = warmup_rate
        self.hard_reset = hard_reset

    def _init_warmup_scheduler(self, optimizer, states):
        """
        Set ``self.warmup_scheduler`` to a LambdaLR over ``optimizer``, or None.

        A warmup scheduler is created when warmup is enabled and either the
        checkpoint described by ``states`` had not finished warming up, or this
        is a hard reset. Saved warmup state is restored only when NOT hard
        resetting, so a hard reset really starts warmup from scratch.

        :param optimizer: optimizer whose param groups are scheduled.
        :param states: checkpoint state dict; None is treated as empty.
        """
        if states is None:
            states = {}
        updates_so_far = states.get('number_training_updates', 0)
        if self.warmup_updates > 0 and (
            updates_so_far < self.warmup_updates or self.hard_reset
        ):
            self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
                optimizer, self._warmup_lr
            )
            # a hard reset must ignore the checkpoint; restoring the saved
            # warmup state would resume the previous run's warmup progress
            if not self.hard_reset and states.get('warmup_scheduler'):
                self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
        else:
            self.warmup_scheduler = None

    def get_last_lr(self):
        """
        Return the current LR of the first param group from the active scheduler.

        The warmup scheduler is consulted while warming up, otherwise the main
        scheduler.
        """
        s = self.warmup_scheduler if self._is_lr_warming_up() else self.scheduler
        try:
            # pytorch 1.5 or newer
            return s.get_last_lr()[0]
        except AttributeError:
            # TODO: upon getting rid of pytorch 1.4, kill this
            # pytorch 1.4 or older
            return s.optimizer.param_groups[0]['lr']

    def _is_lr_warming_up(self):
        """
        Check if we're warming up the learning rate.
        """
        return (
            hasattr(self, 'warmup_scheduler')
            and self.warmup_scheduler is not None
            and self._number_training_updates < self.warmup_updates
        )

    def _warmup_lr(self, step):
        """
        Return lr multiplier (on initial lr) for warmup scheduler.

        Linearly interpolates from ``warmup_rate`` to 1.0 over
        ``warmup_updates`` steps, then stays at 1.0.
        """
        start = self.warmup_rate
        end = 1.0
        progress = min(1.0, step / self.warmup_updates)
        lr_mult = start + (end - start) * progress
        return lr_mult

    def load_state(self, states):
        """
        Load state of scheduler from states.

        Restores the warmup scheduler (if one exists), the main scheduler, and
        the number of training updates taken so far.
        """
        if states.get('warmup_scheduler') and getattr(self, 'warmup_scheduler', None):
            self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
        if self.scheduler and 'lr_scheduler' in states:
            self.scheduler.load_state_dict(states['lr_scheduler'])
        self._number_training_updates = states.get('number_training_updates', 0)
        try:
            if self._is_lr_warming_up():
                self.warmup_scheduler.get_last_lr()
            else:
                self.scheduler.get_last_lr()
        except AttributeError:
            # on older pytorches
            self.step(self._number_training_updates)

    def get_initial_number_training_updates(self):
        """
        Return the number of training updates restored by load_state (or 0).
        """
        return self._number_training_updates

    def get_state_dict(self):
        """
        Return scheduler state dictionary.
        """
        return self.scheduler.state_dict()

    def get_warmup_state_dict(self):
        """
        Return warmup scheduler state dictionary, or None if there is none.
        """
        # getattr: warmup_scheduler only exists after _init_warmup_scheduler
        warmup = getattr(self, 'warmup_scheduler', None)
        if warmup is None:
            return None
        return warmup.state_dict()

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        """
        Register learning-rate scheduler command line arguments on ``parser``.
        """
        lr_group = parser.add_argument_group('Learning Rate Scheduler')
        lr_group.add_argument(
            '--lr-scheduler',
            type=str,
            default='reduceonplateau',
            choices=['reduceonplateau', 'none', 'fixed', 'invsqrt', 'cosine', 'linear'],
            help='Learning rate scheduler.',
        )
        lr_group.add_argument(
            '--lr-scheduler-patience',
            type=int,
            default=3,
            help='LR scheduler patience. In number of validation runs. If using '
            'fixed scheduler, LR is decayed every <patience> validations.',
        )
        lr_group.add_argument(
            '--lr-scheduler-decay',
            type=float,
            default=0.5,
            help='Decay factor for LR scheduler, or how much LR is multiplied by '
            'when it is lowered.',
        )
        lr_group.add_argument(
            '--invsqrt-lr-decay-gamma',
            type=int,
            default=-1,
            help='Constant used only to find the lr multiplier for the invsqrt '
            'scheduler. Must be set for --lr-scheduler invsqrt',
        )
        lr_group.add_argument(
            '--warmup-updates',
            type=int,
            default=-1,
            hidden=True,
            help='Learning rate warmup period, in number of SGD updates. '
            'Linearly scales up LR over period. Only enabled if > 0.',
        )
        lr_group.add_argument(
            '--warmup-rate',
            type=float,
            default=1e-4,
            hidden=True,
            help='Warmup learning rate *multiplier*. Initial LR is multiplied by '
            'this value. Linearly adjusted up to 1.0 across --warmup-updates '
            'steps.',
        )
        lr_group.add_argument(
            '--update-freq',
            type=int,
            default=1,
            hidden=True,
            help='Accumulate gradients N times before performing an optimizer.step().',
        )
        return parser

    @classmethod
    def lr_scheduler_factory(cls, opt, optimizer, states, hard_reset=False):
        """
        Create the learning rate scheduler, and assign it to self.scheduler. This
        scheduler will be updated upon a call to receive_metrics. May also create
        self.warmup_scheduler, if appropriate.

        :param opt opt:
            Arguments received by torch_agent
        :param optimizer optimizer:
            Optimizer being used for training. May be wrapped in
            fp16_optimizer_wrapper depending on whether fp16 is used.
        :param state_dict states:
            Possible state_dict provided by model checkpoint, for restoring
            LR state. None is treated as an empty dict.
        :param bool hard_reset:
            If true, the LR scheduler should ignore the state dictionary.
        :return: ParlAILRScheduler object, or None if scheduling is disabled.
        :raises ValueError: on the deprecated --max-lr-steps option or an
            unknown --lr-scheduler value.
        """
        if states is None:
            states = {}
        patience = opt.get('lr_scheduler_patience', 3)
        decay = opt.get('lr_scheduler_decay', 0.5)
        warmup_updates = opt.get('warmup_updates', -1)
        warmup_rate = opt.get('warmup_rate', 1e-4)
        max_lr_steps = opt.get('max_train_steps', -1)
        if opt.get('max_lr_steps', -1) > 0:
            raise ValueError(
                '--max-lr-steps is **DEPRECATED**; please set --max-train-steps directly'
            )
        invsqrt_lr_decay_gamma = opt.get('invsqrt_lr_decay_gamma', -1)
        if opt.get('lr_scheduler') == 'none':
            return None
        elif decay == 1.0:
            warn_once(
                "Your LR decay is set to 1.0. Assuming you meant you wanted "
                "to disable learning rate scheduling. Adjust --lr-scheduler-decay "
                "if this is not correct."
            )
            return None
        elif opt.get('lr_scheduler') == 'reduceonplateau':
            scheduler = ReduceOnPlateauLRScheduler(
                optimizer, hard_reset, patience, decay, warmup_updates, warmup_rate
            )
        elif opt.get('lr_scheduler') == 'fixed':
            scheduler = FixedLRScheduler(
                optimizer, hard_reset, patience, decay, warmup_updates, warmup_rate
            )
        elif opt.get('lr_scheduler') == 'invsqrt':
            scheduler = InvSqrtLRScheduler(
                optimizer,
                hard_reset,
                patience,
                decay,
                warmup_updates,
                warmup_rate,
                invsqrt_lr_decay_gamma,
                max_lr_steps,
            )
        elif opt.get('lr_scheduler') == 'cosine':
            scheduler = CosineLRScheduler(
                optimizer,
                hard_reset,
                patience,
                decay,
                warmup_updates,
                warmup_rate,
                max_lr_steps,
            )
        elif opt.get('lr_scheduler') == 'linear':
            scheduler = LinearLRScheduler(
                optimizer,
                hard_reset,
                patience,
                decay,
                warmup_updates,
                warmup_rate,
                max_lr_steps,
            )
        else:
            raise ValueError(
                "Don't know what to do with --lr-scheduler '{}'".format(
                    opt.get('lr_scheduler')
                )
            )

        # time to load LR state from the checkpoint, if possible.
        if (
            # there is already an old LR scheduler saved on disk
            states
            # and there was a scheduler in the dump
            and 'lr_scheduler_type' in states
            # and the old LR scheduler is different
            and states.get('lr_scheduler_type') != opt['lr_scheduler']
            # and we're not already using a fresh scheduler
            and not hard_reset
        ):
            # the LR scheduler changed, start things fresh
            warn_once(
                f"LR scheduler ({opt['lr_scheduler']}) is different from saved "
                f"({states.get('lr_scheduler_type')}). Starting fresh!"
            )
            hard_reset = True
            # keep the scheduler's own flag in sync so the warmup setup below
            # also starts fresh instead of restoring the old warmup state
            scheduler.hard_reset = True

        if not hard_reset:
            # do the actual loading (if possible)
            scheduler.load_state(states)

        # setup warmup scheduler after loading saved scheduler
        scheduler._init_warmup_scheduler(optimizer, states)

        return scheduler

    def step(self, num_steps):
        """
        Use the number of train steps to adjust the warmup scheduler or the main
        scheduler, depending on where in training we are.

        Override this method to override the behavior for training schedulers.
        """
        self._number_training_updates = num_steps
        if self._is_lr_warming_up():
            self.warmup_scheduler.step()
        else:
            # main schedulers count steps from the end of warmup
            scheduler_steps = num_steps - self.warmup_updates
            self.train_step(scheduler_steps)

    @abstractmethod
    def train_step(self, scheduler_steps):
        """
        Use the number of train steps to decide when to adjust LR schedule.

        Override this method to override the behavior for training schedulers.
        """
        pass

    @abstractmethod
    def valid_step(self, metrics_dict):
        """
        Use the metrics to decide when to adjust LR schedule.

        This uses the loss as the validation metric if present, if not this
        function does nothing. Note that the model must be reporting loss for
        this to work.

        Override this method to override the behavior for validation schedulers.
        """
        pass
class ReduceOnPlateauLRScheduler(ParlAILRScheduler):
    """
    Multiply the LR by ``decay`` when the validation loss stops improving.

    Wraps torch's ReduceLROnPlateau in 'min' mode with the given ``patience``.
    Only validation results move this schedule; training steps are ignored.
    """

    def __init__(
        self, optimizer, hard_reset, patience, decay, warmup_updates, warmup_rate
    ):
        super().__init__(hard_reset, warmup_updates, warmup_rate)
        plateau_cls = optim.lr_scheduler.ReduceLROnPlateau
        self.scheduler = plateau_cls(
            optimizer, mode='min', factor=decay, patience=patience, verbose=True
        )

    def train_step(self, scheduler_steps):
        # plateau detection is driven purely by validation metrics
        return None

    def valid_step(self, metrics_dict):
        # validation metrics must not influence the LR until warmup finishes
        if self._is_lr_warming_up():
            return
        if 'loss' in metrics_dict:
            self.scheduler.step(metrics_dict['loss'])
        else:
            # nothing to step on, just skip
            warn_once("LR scheduler expected to see loss metric, but didn't.")
class FixedLRScheduler(ParlAILRScheduler):
    """
    Multiply the LR by ``decay`` once every ``patience`` validation steps.

    Backed by torch's StepLR, which is advanced only from valid_step (after
    warmup), never from training steps.
    """

    def __init__(
        self, optimizer, hard_reset, patience, decay, warmup_updates, warmup_rate
    ):
        super().__init__(hard_reset, warmup_updates, warmup_rate)
        step_cls = optim.lr_scheduler.StepLR
        self.scheduler = step_cls(optimizer, step_size=patience, gamma=decay)

    def train_step(self, scheduler_steps):
        # decay is tied to validation runs, not optimizer updates
        return None

    def valid_step(self, metrics_dict):
        # hold the schedule still while the warmup phase is active
        if not self._is_lr_warming_up():
            self.scheduler.step()
class InvSqrtLRScheduler(ParlAILRScheduler):
    """
    Scheduler that decays at an inverse square root rate.
    """

    def __init__(
        self,
        optimizer,
        hard_reset,
        patience,
        decay,
        warmup_updates,
        warmup_rate,
        invsqrt_lr_decay_gamma,
        max_lr_steps,
    ):
        """
        Build an inverse-square-root schedule.

        After warmup, the LR multiplier at scheduler step ``s`` is
        ``sqrt(gamma) / sqrt(gamma + s)`` (both square-root arguments are
        floored at 1), where ``gamma`` is ``invsqrt_lr_decay_gamma``. The
        multiplier is therefore 1.0 when the schedule starts (``s == 0``) and,
        for ``gamma >= 1``, has halved once ``s == 3 * gamma``. In other words,
        gamma sets the decay time scale: larger gamma means slower decay.

        If ``invsqrt_lr_decay_gamma <= 0``, gamma falls back to
        ``warmup_updates`` for backwards compatibility.

        ``max_lr_steps`` is the total step budget including warmup. If the
        post-warmup budget is positive, training is stopped once it is reached.
        ``patience`` and ``decay`` are accepted for a uniform constructor
        signature but are unused here.
        """
        super().__init__(hard_reset, warmup_updates, warmup_rate)
        assert self.warmup_updates >= 0
        # budget for the post-warmup part of the schedule
        self.max_lr_steps = max_lr_steps - self.warmup_updates
        self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma
        if invsqrt_lr_decay_gamma <= 0:
            warn_once(
                '--lr-scheduler invsqrt requires a value for '
                '--invsqrt-lr-decay-gamma. Defaulting to set gamma to '
                '--warmup-updates value for backwards compatibility.'
            )
            self.invsqrt_lr_decay_gamma = self.warmup_updates
        # normalizes the multiplier to exactly 1.0 at scheduler step 0
        self.decay_factor = np.sqrt(max(1, self.invsqrt_lr_decay_gamma))
        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._invsqrt_lr)

    def _invsqrt_lr(self, step):
        """
        Return the LR multiplier ``sqrt(gamma) / sqrt(gamma + step)``.
        """
        return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step))

    def train_step(self, scheduler_steps):
        """
        Advance the schedule by one step; stop training once the budget is used.

        :raises StopTrainException: when a positive post-warmup budget is exhausted.
        """
        if self.max_lr_steps > 0 and scheduler_steps >= self.max_lr_steps:
            raise StopTrainException('Maximum LR steps')
        self.scheduler.step()

    def valid_step(self, metrics_dict):
        # this is a training step lr scheduler, nothing to adjust in validation
        pass
class CosineLRScheduler(ParlAILRScheduler):
    """
    Scheduler that decays by a cosine function.
    """

    def __init__(
        self,
        optimizer,
        hard_reset,
        patience,
        decay,
        warmup_updates,
        warmup_rate,
        max_lr_steps,
    ):
        """
        Build a quarter-period cosine schedule.

        ``max_lr_steps - warmup_updates`` is the length of the post-warmup
        schedule. Over that length the multiplier goes from cos(0) = 1.0 to
        cos(pi/2) = 0.0, following ``cos(pi * step / (2 * length))``. Steps
        outside ``[0, length]`` are clamped, so the multiplier never goes
        negative. Training is stopped once the schedule length is reached.
        ``patience`` and ``decay`` are unused.

        :raises ValueError: if ``max_lr_steps`` (--max-train-steps) is not positive.
        """
        super().__init__(hard_reset, warmup_updates, warmup_rate)
        if max_lr_steps <= 0:
            raise ValueError('--lr-scheduler cosine requires setting --max-train-steps')
        assert self.warmup_updates >= 0
        self.max_lr_steps = max_lr_steps - self.warmup_updates
        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr)

    def _cosine_lr(self, step):
        """
        Return the LR multiplier for post-warmup scheduler step ``step``.
        """
        if self.max_lr_steps <= 0:
            # warmup consumes the whole --max-train-steps budget: train_step
            # stops training as soon as warmup ends, so this schedule never
            # advances. Returning 1.0 avoids dividing by a zero length.
            return 1.0
        # clamp so an out-of-range step cannot produce a negative multiplier
        step = min(max(step, 0), self.max_lr_steps)
        return math.cos(math.pi * step / (2 * self.max_lr_steps))

    def train_step(self, scheduler_steps):
        """
        Advance the schedule by one step; stop training at the end of it.

        :raises StopTrainException: once ``scheduler_steps`` reaches the length.
        """
        if scheduler_steps >= self.max_lr_steps:
            raise StopTrainException('End of Cosine LR Schedule')
        self.scheduler.step()

    def valid_step(self, metrics_dict):
        # training-step scheduler: validation metrics do not affect it
        pass
class LinearLRScheduler(ParlAILRScheduler):
    """
    Scheduler that decays linearly.
    """

    def __init__(
        self,
        optimizer,
        hard_reset,
        patience,
        decay,
        warmup_updates,
        warmup_rate,
        max_lr_steps,
    ):
        """
        Build a linear-decay schedule.

        ``max_lr_steps - warmup_updates`` is the length of the post-warmup
        schedule. Over that length the multiplier decreases linearly from 1.0
        to 0.0 and is floored at 0.0. Training is stopped once the schedule
        length is reached. ``patience`` and ``decay`` are unused.

        :raises ValueError: if ``max_lr_steps`` (--max-train-steps) is not positive.
        """
        super().__init__(hard_reset, warmup_updates, warmup_rate)
        if max_lr_steps <= 0:
            raise ValueError('--lr-scheduler linear requires setting --max-train-steps')
        assert self.warmup_updates >= 0
        self.max_lr_steps = max_lr_steps - self.warmup_updates
        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)

    def _linear_lr(self, step):
        """
        Return the LR multiplier ``max(0, 1 - step / length)``.
        """
        if self.max_lr_steps <= 0:
            # warmup consumes the whole --max-train-steps budget: train_step
            # stops training as soon as warmup ends, so this schedule never
            # advances. Returning 1.0 avoids dividing by a zero length.
            return 1.0
        # this multiplicative factor ensures linear decay rate
        lr_mult = max(0.0, 1.0 - step / self.max_lr_steps)
        return lr_mult

    def train_step(self, scheduler_steps):
        """
        Advance the schedule by one step; stop training at the end of it.

        :raises StopTrainException: once ``scheduler_steps`` reaches the length.
        """
        if scheduler_steps >= self.max_lr_steps:
            raise StopTrainException('End of Linear LR Schedule')
        self.scheduler.step()

    def valid_step(self, metrics_dict):
        # training-step scheduler: validation metrics do not affect it
        pass