From 8664fd5dfd114c09d9e574ff7f84832addc29f7b Mon Sep 17 00:00:00 2001
From: Stephen Roller
Date: Mon, 19 Apr 2021 11:52:35 -0400
Subject: [PATCH] [TA] Fix preemption of cosine scheduler (#3599)

* [TA] Fix preemption of cosine scheduler

* Kurt's advice

* Lint.
---
 parlai/nn/lr_scheduler.py     | 10 ++--
 parlai/scripts/train_model.py | 26 +++++++---
 tests/test_lr_schedulers.py   | 91 +++++++++++++++++++++++++++++++++++
 tests/test_train_model.py     |  4 +-
 4 files changed, 119 insertions(+), 12 deletions(-)

diff --git a/parlai/nn/lr_scheduler.py b/parlai/nn/lr_scheduler.py
index a4d5bb25633..c718633c960 100644
--- a/parlai/nn/lr_scheduler.py
+++ b/parlai/nn/lr_scheduler.py
@@ -9,6 +9,7 @@
 See ParlAILRScheduler (super class) and subclasses for detailed documentation
 """
 
+import math
 from typing import Optional
 from parlai.core.params import ParlaiParser
 from parlai.core.opt import Opt
@@ -462,9 +463,12 @@ def __init__(
         """
         super().__init__(hard_reset, warmup_updates, warmup_rate)
         if max_lr_steps <= 0:
-            raise ValueError('--lr-scheduler cosine requires setting --max-lr-steps')
+            raise ValueError('--lr-scheduler cosine requires setting --max-train-steps')
         self.max_lr_steps = max_lr_steps
-        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, max_lr_steps)
+        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr)
+
+    def _cosine_lr(self, step):
+        return math.cos(math.pi * step / (2 * self.max_lr_steps))
 
     def train_step(self, scheduler_steps):
         if scheduler_steps >= self.max_lr_steps:
@@ -497,7 +501,7 @@ def __init__(
         """
         super().__init__(hard_reset, warmup_updates, warmup_rate)
         if max_lr_steps <= 0:
-            raise ValueError('--lr-scheduler linear requires setting --max-lr-steps')
+            raise ValueError('--lr-scheduler linear requires setting --max-train-steps')
         self.max_lr_steps = max_lr_steps
         self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)
 
diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py
index f17fa9c40ad..ff8a3b4f694 100644
--- a/parlai/scripts/train_model.py
+++ b/parlai/scripts/train_model.py
@@ -784,11 +784,13 @@ def log(self):
         if opt['wandb_log'] and is_primary_worker():
             self.wb_logger.log_metrics('train', self.parleys, train_report)
 
-    def train(self):
+        return train_report
+
+    def train_steps(self):
         """
-        Perform a training run.
+        Core training loop.
 
-        :return: tuple of reports (validation_report, test_report)
+        Yields a metrics dict with each log.
         """
         logging.info('training...')
         opt = self.opt
@@ -814,7 +816,7 @@ def train(self):
 
             # check counters and timers
             if self._total_epochs >= self.max_num_epochs:
-                self.log()
+                yield self.log()
                 logging.info(
                     f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                 )
@@ -832,7 +834,7 @@ def train(self):
                 log_time > self.log_every_n_secs
                 or self._last_log_steps >= self.log_every_n_steps
             ):
-                self.log()
+                yield self.log()
             if (
                 validate_time > self.val_every_n_secs
                 or self._total_epochs - self.last_valid_epoch
@@ -842,7 +844,8 @@ def train(self):
             ):
                 try:
                     # log before we validate
-                    self.log()
+                    if self._last_log_steps:
+                        yield self.log()
                     world.reset_metrics()
                     stop_training = self.validate()
                 except StopTrainException:
@@ -886,6 +889,17 @@ def train(self):
             # reload best validation model
             self.agent = create_agent(opt)
 
+    def train(self):
+        """
+        Perform a training run.
+
+        :return: tuple of reports (validation_report, test_report)
+        """
+        opt = self.opt
+        for _train_log in self.train_steps():
+            # the logging side effects already happened inside train_steps
+            pass
+
         # perform final validation/testing
         valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
         max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py
index 843a476252b..2ce941c446a 100644
--- a/tests/test_lr_schedulers.py
+++ b/tests/test_lr_schedulers.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import unittest
 import torch
 import parlai.nn.lr_scheduler as lr_scheduler
@@ -142,3 +143,93 @@ def test_end2end_linear(self):
     def test_end2end_invsqrt(self):
         self._run_end2end(lr_scheduler='invsqrt', warmup_updates=0)
         self._run_end2end(lr_scheduler='invsqrt', warmup_updates=50)
+
+
+class TestLRIntegration(unittest.TestCase):
+    """
+    Deep LR scheduler tests to check how we handle preemption.
+    """
+
+    PREEMPT = 30
+
+    def _test_scheduler(self, **kwargs):
+        from parlai.scripts.train_model import TrainModel, TrainLoop
+
+        # shallow copy to prevent overwrites
+        kwargs = kwargs.copy()
+        with testing_utils.tempdir() as tmpdir:
+            kwargs['model'] = 'test_agents/unigram'
+            kwargs['task'] = 'integration_tests'
+            kwargs['skip_generation'] = True
+            kwargs['validation_metric'] = 'loss'
+            kwargs['model_file'] = os.path.join(tmpdir, 'model')
+            kwargs['dict_file'] = 'zoo:unittest/transformer_generator2/model.dict'
+            kwargs['log_every_n_steps'] = 1
+            kwargs['validation_every_n_steps'] = 10
+            kwargs['max_train_steps'] = 100
+            kwargs['save_after_valid'] = True
+            kwargs['learningrate'] = 1
+            opt = TrainModel.setup_args().parse_kwargs(**kwargs)
+
+            logs_first = []
+            for i, train_step_log in enumerate(TrainLoop(opt).train_steps(), 1):
+                logs_first.append(train_step_log)
+                if i >= self.PREEMPT - 2:
+                    # simulate preemption
+                    break
+
+            # resume training
+            logs_second = []
+            for train_step_log in TrainLoop(opt).train_steps():
+                logs_second.append(train_step_log)
+
+            # check correctness
+            assert (
+                logs_first[20]['total_train_updates']
+                == logs_second[0]['total_train_updates']
+            )
+            assert logs_first[20]['lr'] == logs_second[0]['lr']
+
+            if 'warmup_updates' in kwargs:
+                full_logs = logs_first[:20] + logs_second
+                assert full_logs[kwargs['warmup_updates']]['lr'] == 1.0
+
+            return logs_first, logs_second
+
+    def test_invsqrt(self):
+        self._test_scheduler(lr_scheduler='invsqrt')
+
+    def test_invsqrt_warmup(self):
+        self._test_scheduler(lr_scheduler='invsqrt', warmup_updates=25)
+
+    def test_invsqrt_long_warmup(self):
+        self._test_scheduler(lr_scheduler='invsqrt', warmup_updates=self.PREEMPT + 30)
+
+    def test_reduceonplateau(self):
+        self._test_scheduler(lr_scheduler='reduceonplateau')
+
+    def test_reduceonplateau_warmup(self):
+        self._test_scheduler(lr_scheduler='reduceonplateau', warmup_updates=25)
+
+    def test_reduceonplateau_long_warmup(self):
+        self._test_scheduler(
+            lr_scheduler='reduceonplateau', warmup_updates=self.PREEMPT + 30
+        )
+
+    def test_linear(self):
+        self._test_scheduler(lr_scheduler='linear')
+
+    def test_linear_warmup(self):
+        self._test_scheduler(lr_scheduler='linear', warmup_updates=25)
+
+    def test_linear_long_warmup(self):
+        self._test_scheduler(lr_scheduler='linear', warmup_updates=self.PREEMPT + 30)
+
+    def test_cosine(self):
+        self._test_scheduler(lr_scheduler='cosine')
+
+    def test_cosine_warmup(self):
+        self._test_scheduler(lr_scheduler='cosine', warmup_updates=25)
+
+    def test_cosine_long_warmup(self):
+        self._test_scheduler(lr_scheduler='cosine', warmup_updates=self.PREEMPT + 30)
diff --git a/tests/test_train_model.py b/tests/test_train_model.py
index d06c086938a..c3aa0f5f842 100644
--- a/tests/test_train_model.py
+++ b/tests/test_train_model.py
@@ -154,9 +154,7 @@ def get_tl(tmpdir):
             tl.valid_reports[-1]['total_train_updates'], num_train_steps - 1
         )
         self.assertEqual(len(tl.valid_reports), num_validations)
-        self.assertEqual(
-            len(tl.train_reports), num_logs + num_validations
-        )  # log every valid as well
+        self.assertEqual(len(tl.train_reports), num_logs)  # validation no longer forces an extra log
 
     def test_opt_step(self):
         self._test_opt_step_opts(1)
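
Note on the fix, with a standalone sketch (not part of the patch): the cosine
schedule is now driven by LambdaLR with a quarter-cosine factor that is a pure
function of the global step, so a resumed job only needs its step counter to
land back on the same learning rate, which is what the new integration test
asserts. The names below (MAX_LR_STEPS, fresh_scheduler, the 30-step run) are
illustrative stand-ins, not ParlAI code:

    # Rebuilding the scheduler after a "preemption" and replaying the step
    # counter reproduces the learning rate exactly.
    import math

    import torch
    from torch import optim

    MAX_LR_STEPS = 100  # stands in for --max-train-steps


    def cosine_lr(step):
        # mirrors CosineLRScheduler._cosine_lr: 1.0 at step 0, 0.0 at MAX_LR_STEPS
        return math.cos(math.pi * step / (2 * MAX_LR_STEPS))


    def fresh_scheduler():
        # a toy one-parameter model; only the LR bookkeeping matters here
        param = torch.nn.Parameter(torch.zeros(1))
        optimizer = optim.SGD([param], lr=1.0)
        return optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)


    first = fresh_scheduler()
    for _ in range(30):
        first.step()

    # "preempt" here, then resume: a fresh scheduler fast-forwarded to the
    # same step count lands on the identical learning rate
    resumed = fresh_scheduler()
    for _ in range(30):
        resumed.step()

    assert resumed.get_last_lr() == first.get_last_lr()
    print(first.get_last_lr())  # [cos(pi * 30 / 200)] ~= [0.891]

A stateful schedule object, by contrast, only resumes correctly if its own
state is saved and restored alongside the model, which is the failure mode the
preemption tests above guard against.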