
[TA] Fix preemption of cosine scheduler (#3599)
* [TA] Fix preemption of cosine scheduler

* Kurt's advice

* Lint.
stephenroller authored Apr 19, 2021
1 parent 7d26830 commit 8664fd5
Showing 4 changed files with 119 additions and 12 deletions.
10 changes: 7 additions & 3 deletions parlai/nn/lr_scheduler.py
@@ -9,6 +9,7 @@
 See ParlAILRScheduler (super class) and subclasses for detailed documentation
 """
 
+import math
 from typing import Optional
 from parlai.core.params import ParlaiParser
 from parlai.core.opt import Opt
@@ -462,9 +463,12 @@ def __init__(
         """
         super().__init__(hard_reset, warmup_updates, warmup_rate)
         if max_lr_steps <= 0:
-            raise ValueError('--lr-scheduler cosine requires setting --max-lr-steps')
+            raise ValueError('--lr-scheduler cosine requires setting --max-train-steps')
         self.max_lr_steps = max_lr_steps
-        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, max_lr_steps)
+        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr)
+
+    def _cosine_lr(self, step):
+        return math.cos(math.pi * step / (2 * self.max_lr_steps))
 
     def train_step(self, scheduler_steps):
         if scheduler_steps >= self.max_lr_steps:
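Why this fixes preemption: LambdaLR makes the learning rate a pure function of the step count, so a run restored from a checkpoint lands on exactly the lr it would have had uninterrupted, whereas the CosineAnnealingLR it replaces computes each lr from the optimizer's current lr in recent PyTorch versions, which is presumably the drift this commit fixes. Note the new schedule is a quarter-cosine, cos(pi * step / (2 * max_lr_steps)), running from 1.0 at step 0 to 0.0 at max_lr_steps, rather than CosineAnnealingLR's half-cosine. A minimal sketch of the resume behavior (not ParlAI code; plain PyTorch, with MAX_LR_STEPS standing in for --max-train-steps):

import math
import torch
from torch import optim

MAX_LR_STEPS = 100  # stands in for --max-train-steps

def cosine_lr(step):
    # quarter-cosine decay: multiplier is 1.0 at step 0, 0.0 at MAX_LR_STEPS
    return math.cos(math.pi * step / (2 * MAX_LR_STEPS))

param = torch.nn.Parameter(torch.zeros(1))
opt1 = optim.SGD([param], lr=1.0)
sched1 = optim.lr_scheduler.LambdaLR(opt1, cosine_lr)
for _ in range(30):
    opt1.step()
    sched1.step()
state = sched1.state_dict()  # what a checkpoint would save

# simulate preemption: rebuild from scratch, restore scheduler state
opt2 = optim.SGD([param], lr=1.0)
sched2 = optim.lr_scheduler.LambdaLR(opt2, cosine_lr)
sched2.load_state_dict(state)
opt2.step()
sched2.step()
assert sched2.get_last_lr()[0] == cosine_lr(31)  # identical to an uninterrupted run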
@@ -497,7 +501,7 @@ def __init__(
         """
         super().__init__(hard_reset, warmup_updates, warmup_rate)
         if max_lr_steps <= 0:
-            raise ValueError('--lr-scheduler linear requires setting --max-lr-steps')
+            raise ValueError('--lr-scheduler linear requires setting --max-train-steps')
         self.max_lr_steps = max_lr_steps
         self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)
26 changes: 20 additions & 6 deletions parlai/scripts/train_model.py
@@ -784,11 +784,13 @@ def log(self):
         if opt['wandb_log'] and is_primary_worker():
             self.wb_logger.log_metrics('train', self.parleys, train_report)
 
-    def train(self):
+        return train_report
+
+    def train_steps(self):
         """
-        Perform a training run.
+        Core training loop.
 
-        :return: tuple of reports (validation_report, test_report)
+        Yields a metrics dict with each log.
         """
         logging.info('training...')
         opt = self.opt
@@ -814,7 +816,7 @@ def train(self):
 
                 # check counters and timers
                 if self._total_epochs >= self.max_num_epochs:
-                    self.log()
+                    yield self.log()
                     logging.info(
                         f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                     )
@@ -832,7 +834,7 @@ def train(self):
                     log_time > self.log_every_n_secs
                     or self._last_log_steps >= self.log_every_n_steps
                 ):
-                    self.log()
+                    yield self.log()
                 if (
                     validate_time > self.val_every_n_secs
                     or self._total_epochs - self.last_valid_epoch
@@ -842,7 +844,8 @@ def train(self):
                 ):
                     try:
                         # log before we validate
-                        self.log()
+                        if self._last_log_steps:
+                            yield self.log()
                         world.reset_metrics()
                         stop_training = self.validate()
                     except StopTrainException:
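Note: the new self._last_log_steps guard presumably keeps the pre-validation log from yielding an empty or duplicate report when no train steps have happened since the previous log; the old code logged unconditionally before every validation, which is also why the expected report count changes in tests/test_train_model.py below.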
@@ -886,6 +889,17 @@ def train(self):
             # reload best validation model
             self.agent = create_agent(opt)
 
+    def train(self):
+        """
+        Perform a training run.
+
+        :return: tuple of reports (validation_report, test_report)
+        """
+        opt = self.opt
+        for _train_log in self.train_steps():
+            # we've already done what we need in these
+            pass
+
         # perform final validation/testing
         valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
         max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
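Reduced to a toy (hypothetical names, not ParlAI code), the shape of this refactor: the core loop becomes a generator that yields one metrics dict per log, and the public train() simply drains it, so a test can abandon the generator mid-run to simulate preemption without patching the loop:

from itertools import islice

class ToyLoop:
    """Stands in for TrainLoop; real state lives in the agent/world."""

    def __init__(self, max_steps=100):
        self.step = 0
        self.max_steps = max_steps

    def train_steps(self):
        # core loop: yields a metrics dict with each log
        while self.step < self.max_steps:
            self.step += 1
            yield {'total_train_updates': self.step}

    def train(self):
        # original entry point, now a thin wrapper over the generator
        for _log in self.train_steps():
            pass  # we've already done what we need in these
        return self.step

loop = ToyLoop()
preempted = list(islice(loop.train_steps(), 28))  # walk away mid-run
assert preempted[-1]['total_train_updates'] == 28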
91 changes: 91 additions & 0 deletions tests/test_lr_schedulers.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import unittest
 import torch
 import parlai.nn.lr_scheduler as lr_scheduler
@@ -142,3 +143,93 @@ def test_end2end_linear(self):
     def test_end2end_invsqrt(self):
         self._run_end2end(lr_scheduler='invsqrt', warmup_updates=0)
         self._run_end2end(lr_scheduler='invsqrt', warmup_updates=50)
+
+
+class TestLRIntegration(unittest.TestCase):
+    """
+    Deep LR scheduler tests to check how we handle preemption.
+    """
+
+    PREEMPT = 30
+
+    def _test_scheduler(self, **kwargs):
+        from parlai.scripts.train_model import TrainModel, TrainLoop
+
+        # shallow copy to prevent overwrites
+        kwargs = kwargs.copy()
+        with testing_utils.tempdir() as tmpdir:
+            kwargs['model'] = 'test_agents/unigram'
+            kwargs['task'] = 'integration_tests'
+            kwargs['skip_generation'] = True
+            kwargs['validation_metric'] = 'loss'
+            kwargs['model_file'] = os.path.join(tmpdir, 'model')
+            kwargs['dict_file'] = 'zoo:unittest/transformer_generator2/model.dict'
+            kwargs['log_every_n_steps'] = 1
+            kwargs['validation_every_n_steps'] = 10
+            kwargs['max_train_steps'] = 100
+            kwargs['save_after_valid'] = True
+            kwargs['learningrate'] = 1
+            opt = TrainModel.setup_args().parse_kwargs(**kwargs)
+
+            logs_first = []
+            for i, train_step_log in enumerate(TrainLoop(opt).train_steps(), 1):
+                logs_first.append(train_step_log)
+                if i >= self.PREEMPT - 2:
+                    # simulate preemption
+                    break
+
+            # resume training
+            logs_second = []
+            for train_step_log in TrainLoop(opt).train_steps():
+                logs_second.append(train_step_log)
+
+            # check correctness
+            assert (
+                logs_first[20]['total_train_updates']
+                == logs_second[0]['total_train_updates']
+            )
+            assert logs_first[20]['lr'] == logs_second[0]['lr']
+
+            if 'warmup_updates' in kwargs:
+                full_logs = logs_first[:20] + logs_second
+                assert full_logs[kwargs['warmup_updates']]['lr'] == 1.0
+
+        return logs_first, logs_second
+
+    def test_invsqrt(self):
+        self._test_scheduler(lr_scheduler='invsqrt')
+
+    def test_invsqrt_warmup(self):
+        self._test_scheduler(lr_scheduler='invsqrt', warmup_updates=25)
+
+    def test_invsqrt_long_warmup(self):
+        self._test_scheduler(lr_scheduler='invsqrt', warmup_updates=self.PREEMPT + 30)
+
+    def test_reduceonplateau(self):
+        self._test_scheduler(lr_scheduler='reduceonplateau')
+
+    def test_reduceonplateau_warmup(self):
+        self._test_scheduler(lr_scheduler='reduceonplateau', warmup_updates=25)
+
+    def test_reduceonplateau_long_warmup(self):
+        self._test_scheduler(
+            lr_scheduler='reduceonplateau', warmup_updates=self.PREEMPT + 30
+        )
+
+    def test_linear(self):
+        self._test_scheduler(lr_scheduler='linear')
+
+    def test_linear_warmup(self):
+        self._test_scheduler(lr_scheduler='linear', warmup_updates=25)
+
+    def test_linear_long_warmup(self):
+        self._test_scheduler(lr_scheduler='linear', warmup_updates=self.PREEMPT + 30)
+
+    def test_cosine(self):
+        self._test_scheduler(lr_scheduler='cosine')
+
+    def test_cosine_warmup(self):
+        self._test_scheduler(lr_scheduler='cosine', warmup_updates=25)
+
+    def test_cosine_long_warmup(self):
+        self._test_scheduler(lr_scheduler='cosine', warmup_updates=self.PREEMPT + 30)
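A note on the index arithmetic above: with log_every_n_steps=1, validation_every_n_steps=10, and save_after_valid=True, the simulated preemption at step PREEMPT - 2 = 28 means the last checkpoint was written at step 20, so the resumed run's first log (logs_second[0]) should line up with logs_first[20], i.e. step 21: same total_train_updates and, the point of this fix, same lr. Setting learningrate=1 makes the scheduler's multiplier directly readable from the 'lr' field. Assuming the repo's standard pytest setup, the new tests can be run with something like:

python -m pytest tests/test_lr_schedulers.py::TestLRIntegration -v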
4 changes: 1 addition & 3 deletions tests/test_train_model.py
@@ -154,9 +154,7 @@ def get_tl(tmpdir):
             tl.valid_reports[-1]['total_train_updates'], num_train_steps - 1
         )
         self.assertEqual(len(tl.valid_reports), num_validations)
-        self.assertEqual(
-            len(tl.train_reports), num_logs + num_validations
-        )  # log every valid as well
+        self.assertEqual(len(tl.train_reports), num_logs)  # validation no longer forces an extra log

def test_opt_step(self):
self._test_opt_step_opts(1)
