Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
LR Scheduler fixes (#3025)
Browse files Browse the repository at this point in the history
* Avoid some LR scheduler warnings.

* Round out schedulers. And end to end tests.

* unigram agent works with beam search
  • Loading branch information
stephenroller authored Sep 4, 2020
1 parent 9c894ba commit 5de0fbc
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 19 deletions.
51 changes: 51 additions & 0 deletions parlai/agents/test_agents/unigram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
UnigramAgent always predicts the unigram distribution.
It is a full TorchGeneratorAgent model, so it can be used heavily in testing, while
being very quick to optimize.
"""

import torch
import torch.nn as nn
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel


class UnigramEncoder(nn.Module):
    """
    Encoder stub that ignores its input and produces no encoder state.
    """

    def forward(self, x):
        """
        Discard the input.

        :param x:
            encoder input; never read.

        :return:
            always ``None``.
        """
        return None


class UnigramDecoder(nn.Module):
    """
    Pass-through decoder: it performs no learned computation on its input.
    """

    def forward(self, x, encoder_state, incr_state=None):
        """
        Return the input with one extra trailing dimension.

        :param x:
            decoder input tensor.
        :param encoder_state:
            unused.
        :param incr_state:
            unused.

        :return:
            a pair ``(x.unsqueeze(-1), None)``. The second slot is always
            ``None`` (presumably the incremental decoding state -- confirm
            against the caller).
        """
        return x.unsqueeze(-1), None


class UnigramModel(TorchGeneratorModel):
    """
    Generator model whose output scores do not depend on the input.

    The only learnable state is ``self.p``, a vector holding one score per
    dictionary entry; ``output`` broadcasts that same vector to every position.
    """

    def __init__(self, dictionary):
        """
        Build the stub encoder/decoder and the score vector.

        :param dictionary:
            any sized object; ``len(dictionary)`` fixes the width of the score
            vector.
        """
        super().__init__()
        self.encoder = UnigramEncoder()
        self.decoder = UnigramDecoder()
        self.v = len(dictionary)
        # One score per dictionary entry, initialised to zero.
        self.p = nn.Parameter(torch.zeros(self.v))

    def output(self, do):
        """
        Broadcast the score vector to match the decoder output.

        :param do:
            decoder output tensor. Only its first two dimensions are read
            (presumably batch and time -- confirm against the caller).

        :return:
            a view of ``self.p`` expanded to ``do.shape[:2] + [self.v]``.
        """
        desired = list(do.shape)[:2] + [self.v]
        # Add two leading singleton dims so ``expand`` can broadcast over them.
        x = self.p.unsqueeze(0).unsqueeze(0)
        return x.expand(desired)

    def reorder_encoder_states(self, *args):
        """
        Do nothing: the encoder produces no state, so nothing needs reordering.

        :return:
            always ``None``.
        """
        return None

    def reorder_decoder_incremental_state(self, *args):
        """
        Do nothing: the decoder keeps no incremental state.

        :return:
            always ``None``.
        """
        return None


class UnigramAgent(TorchGeneratorAgent):
    """
    Generator agent backed by :class:`UnigramModel`.
    """

    def build_model(self):
        """
        Construct the model from this agent's dictionary.

        :return:
            a new ``UnigramModel`` sized by ``self.dict``.
        """
        return UnigramModel(self.dict)
2 changes: 1 addition & 1 deletion parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ def report(self):

# only report LR if we have a scheduler
if hasattr(self, 'scheduler') and self.scheduler is not None:
report['lr'] = GlobalAverageMetric(self.optimizer.param_groups[0]['lr'])
report['lr'] = GlobalAverageMetric(self.scheduler.get_last_lr())

if self.use_cuda:
report['gpu_mem'] = GlobalAverageMetric(self._gpu_usage())
Expand Down
42 changes: 32 additions & 10 deletions parlai/nn/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,33 @@ def __init__(self, hard_reset, warmup_updates, warmup_rate):
Starting multiplier for warmup scheduler.
"""
self._number_training_updates = 0
self.warmup_updates = warmup_updates
self.warmup_updates = max(0, warmup_updates)
self.warmup_rate = warmup_rate
self.hard_reset = hard_reset

def _init_warmup_scheduler(self, optimizer, states):
updates_so_far = states.get('number_training_updates', 0)
if self.warmup_updates > 0 and (
updates_so_far < self.warmup_updates or self.hard_reset
updates_so_far <= self.warmup_updates or self.hard_reset
):
self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
optimizer, self._warmup_lr
)
if states.get('warmup_scheduler'):
self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
else:
self.warmup_scheduler = None

def get_last_lr(self):
    """
    Return the most recent learning rate of the currently active scheduler.

    The warmup scheduler is consulted while ``self._is_lr_warming_up()`` is
    true; otherwise the main scheduler is used.

    :return:
        the learning rate of the first parameter group.
    """
    active = self.warmup_scheduler if self._is_lr_warming_up() else self.scheduler
    try:
        # pytorch 1.5 or newer
        return active.get_last_lr()[0]
    except AttributeError:
        # TODO: upon getting rid of pytorch 1.4, kill this
        # pytorch 1.4 or older: read the LR straight off the optimizer
        return active.optimizer.param_groups[0]['lr']

def _is_lr_warming_up(self):
"""
Check if we're warming up the learning rate.
Expand All @@ -87,12 +99,17 @@ def load_state(self, states):
"""
Load state of scheduler from states.
"""
if self.scheduler and 'lr_scheduler' in states:
self.scheduler.load_state_dict(states['lr_scheduler'])
if states.get('warmup_scheduler') and getattr(self, 'warmup_scheduler', None):
self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
if self.scheduler and 'lr_scheduler' in states:
self.scheduler.load_state_dict(states['lr_scheduler'])
self._number_training_updates = states.get('number_training_updates', 0)
self.step(self._number_training_updates)

try:
self.scheduler.get_last_lr()
except AttributeError:
# on older pytorches
self.step(self._number_training_updates)

def get_initial_number_training_updates(self):
return self._number_training_updates
Expand Down Expand Up @@ -228,6 +245,7 @@ def lr_scheduler_factory(cls, opt, optimizer, states, hard_reset=False):
warmup_updates,
warmup_rate,
invsqrt_lr_decay_gamma,
max_lr_steps,
)
elif opt.get('lr_scheduler') == 'cosine':
scheduler = CosineLRScheduler(
Expand Down Expand Up @@ -292,7 +310,7 @@ def step(self, num_steps):
"""
self._number_training_updates = num_steps
if self._is_lr_warming_up():
self.warmup_scheduler.step(epoch=num_steps)
self.warmup_scheduler.step()
else:
scheduler_steps = num_steps - self.warmup_updates
self.train_step(scheduler_steps)
Expand Down Expand Up @@ -384,6 +402,7 @@ def __init__(
warmup_updates,
warmup_rate,
invsqrt_lr_decay_gamma,
max_lr_steps,
):
"""
invsqrt_lr_decay_gamma determines the cycle length of the inverse square root
Expand All @@ -392,6 +411,7 @@ def __init__(
When steps taken == invsqrt_lr_decay_gamma, the lr multiplier is 1
"""
super().__init__(hard_reset, warmup_updates, warmup_rate)
self.max_lr_steps = max_lr_steps
self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma
if invsqrt_lr_decay_gamma <= 0:
warn_once(
Expand All @@ -408,7 +428,9 @@ def _invsqrt_lr(self, step):
return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step))

def train_step(self, scheduler_steps):
self.scheduler.step(epoch=scheduler_steps)
if self.max_lr_steps > 0 and scheduler_steps >= self.max_lr_steps:
raise StopTrainException('Maximum LR steps')
self.scheduler.step()

def valid_step(self, metrics_dict):
# this is a training step lr scheduler, nothing to adjust in validation
Expand Down Expand Up @@ -445,7 +467,7 @@ def __init__(
def train_step(self, scheduler_steps):
if scheduler_steps >= self.max_lr_steps:
raise StopTrainException('End of Cosine LR Schedule')
self.scheduler.step(epoch=scheduler_steps)
self.scheduler.step()

def valid_step(self, metrics_dict):
pass
Expand Down Expand Up @@ -480,13 +502,13 @@ def __init__(
def _linear_lr(self, step):
# this multiplicative factor ensures linear decay rate
# lr_mult = float(self.max_lr_steps - step - 1) / float(self.max_lr_steps - step)
lr_mult = max(0.0, 1.0 - step / self.max_lr_steps)
lr_mult = max(0.0, 1e-6 + (1.0 - step / self.max_lr_steps) * (1 - 1e-6))
return lr_mult

def train_step(self, scheduler_steps):
if scheduler_steps >= self.max_lr_steps:
raise StopTrainException('End of Linear LR Schedule')
self.scheduler.step(epoch=scheduler_steps)
self.scheduler.step()

def valid_step(self, metrics_dict):
pass
3 changes: 2 additions & 1 deletion parlai/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,8 @@ def train(self):
# do one example / batch of examples
try:
world.parley()
except StopTrainException:
except StopTrainException as e:
logging.info(f"Stopping from {e}")
break

self.parleys += 1
Expand Down
91 changes: 84 additions & 7 deletions tests/test_lr_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest
import torch
import parlai.nn.lr_scheduler as lr_scheduler
import parlai.utils.testing as testing_utils


class TestLRSchedulers(unittest.TestCase):
Expand All @@ -22,33 +23,70 @@ def _run_pass(self, max_lr=1.0, warmup_updates=0, total_steps=1000, **args):
output = []
for step in range(total_steps):
scheduler.step(step)
output.append(optimizer.param_groups[0]['lr'])
output.append(scheduler.get_last_lr())
for step, o in enumerate(output): # noqa: B007
assert o <= max_lr
assert o > 0
assert o > 0 or step == total_steps - 1
warmup_updates = args.get('warmup_updates', 0)
if warmup_updates > 0:
assert output[warmup_updates] == max_lr
assert output[warmup_updates - 1] == max_lr
# no steep cliffs of > 50% of LR
assert (output[warmup_updates + 1] - max_lr) / max_lr < 0.5
assert (output[warmup_updates] - max_lr) / max_lr < 0.5
# LR is always linear
for step in range(warmup_updates - 1):
self.assertAlmostEqual(
output[step + 1] - output[step], max_lr / warmup_updates, places=3,
output[step + 1] - output[step], max_lr / warmup_updates, places=3
)
return output

def _run_resume(self, max_lr=1.0, warmup_updates=0, total_steps=200, **args):
args['warmup_updates'] = warmup_updates
if 'max_lr_steps' not in args:
args['max_lr_steps'] = total_steps - warmup_updates
p = torch.nn.Parameter(torch.randn(4, 4))
optimizer = torch.optim.SGD([p], lr=max_lr)
scheduler = lr_scheduler.ParlAILRScheduler.lr_scheduler_factory(
args, optimizer, {}, True
)

for step in range(total_steps):
p = torch.nn.Parameter(torch.randn(4, 4))
optimizer2 = torch.optim.SGD([p], lr=max_lr)
sd = {
'number_training_updates': step + 1,
'lr_scheduler': scheduler.get_state_dict(),
'lr_scheduler_type': args['lr_scheduler'],
'warmup_scheduler': scheduler.get_warmup_state_dict(),
}
scheduler2 = lr_scheduler.ParlAILRScheduler.lr_scheduler_factory(
args, optimizer2, sd, False
)
assert scheduler.get_last_lr() == scheduler2.get_last_lr(), step
scheduler.step(step)

sd = {
'number_training_updates': step,
'lr_scheduler': scheduler.get_state_dict(),
'lr_scheduler_type': args['lr_scheduler'],
'warmup_scheduler': scheduler.get_warmup_state_dict(),
}
optimizer2 = torch.optim.SGD([p], lr=max_lr)
scheduler2 = lr_scheduler.ParlAILRScheduler.lr_scheduler_factory(
args, optimizer2, sd, False
)
assert scheduler.get_last_lr() == scheduler2.get_last_lr()

def test_cosine(self):
self._run_pass(lr_scheduler='cosine', warmup_updates=0)
self._run_pass(lr_scheduler='cosine', warmup_updates=50)
with self.assertRaises(lr_scheduler.StopTrainException):
self._run_pass(lr_scheduler='cosine', max_lr_steps=100, total_steps=1000)

def test_linear_warmup(self):
def test_linear(self):
self._run_pass(lr_scheduler='linear', warmup_updates=0)
self._run_pass(lr_scheduler='linear', warmup_updates=50)
with self.assertRaises(lr_scheduler.StopTrainException):
self._run_pass(lr_scheduler='cosine', max_lr_steps=100, total_steps=1000)
self._run_pass(lr_scheduler='linear', max_lr_steps=100, total_steps=1000)

def test_invsqrt(self):
self._run_pass(lr_scheduler='invsqrt', warmup_updates=0)
Expand All @@ -65,3 +103,42 @@ def test_invsqrt(self):
lr_scheduler='invsqrt', warmup_updates=50, invsqrt_lr_decay_gamma=5000
)
assert all(x > 0.9 for x in steps[50:])

def test_cosine_resume(self):
self._run_resume(lr_scheduler='cosine', warmup_updates=0)
self._run_resume(lr_scheduler='cosine', warmup_updates=50)

def test_linear_resume(self):
self._run_resume(lr_scheduler='linear', warmup_updates=0)
self._run_resume(lr_scheduler='linear', warmup_updates=50)

def test_invsqrt_resume(self):
self._run_resume(lr_scheduler='invsqrt', warmup_updates=0)
self._run_resume(lr_scheduler='invsqrt', warmup_updates=50)

def _run_end2end(
    self, lr_scheduler, max_lr=1.0, warmup_updates=0, total_steps=100, **args
):
    """
    Train the unigram test agent end to end with the given LR scheduler.

    :param lr_scheduler:
        value passed as the ``lr_scheduler`` option.
    :param max_lr:
        value passed as the ``learningrate`` option.
    :param warmup_updates:
        value passed as the ``warmup_updates`` option.
    :param total_steps:
        value passed as the ``max_lr_steps`` option.
    :param args:
        additional options forwarded to ``train_model``; these take precedence
        over the defaults built here.
    """
    opt = {
        'task': 'integration_tests:nocandidate',
        'model': 'test_agents/unigram',
        'skip_generation': True,
        'lr_scheduler': lr_scheduler,
        'max_lr_steps': total_steps,
        'warmup_updates': warmup_updates,
        'learningrate': max_lr,
    }
    # Forward any extra keyword options instead of silently discarding them.
    opt.update(args)
    testing_utils.train_model(opt)

def test_end2end_cosine(self):
self._run_end2end(lr_scheduler='cosine', warmup_updates=0)
self._run_end2end(lr_scheduler='cosine', warmup_updates=50)

def test_end2end_linear(self):
self._run_end2end(lr_scheduler='linear', warmup_updates=0)
self._run_end2end(lr_scheduler='linear', warmup_updates=50)

def test_end2end_invsqrt(self):
self._run_end2end(lr_scheduler='invsqrt', warmup_updates=0)
self._run_end2end(lr_scheduler='invsqrt', warmup_updates=50)
1 change: 1 addition & 0 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ def _test_learning_rate_resuming(self, args):
init_model=os.path.join(tmpdir, 'model'),
model_file=os.path.join(tmpdir, 'newmodel2'),
lr_scheduler='reduceonplateau',
log_every_n_secs=0.001,
**args,
)
)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_unigram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ def test_unigram(self):
{'model': 'unigram', 'task': 'integration_tests', 'num_epochs': 0.01}
)
assert valid['f1'] > 0


class TestUnigramTorchAgent(unittest.TestCase):
def test_unigram(self):
valid, test = testing_utils.train_model(
{
'model': 'test_agents/unigram',
'task': 'integration_tests',
'num_epochs': 1.0,
'batchsize': 32,
'truncate': 4,
}
)
assert valid['f1'] > 0

0 comments on commit 5de0fbc

Please sign in to comment.