From 1db545adc858c24f93a0f54ca24570ad126fe017 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 27 Mar 2020 03:07:22 +0100
Subject: [PATCH] Disable validation when val_percent_check=0 (#1251)

* fix disable validation

* add test

* update changelog

* update docs for val_percent_check

* make "fast training" docs consistent
---
 CHANGELOG.md                         |  2 +-
 docs/source/fast_training.rst        | 29 ++++++++--------
 pytorch_lightning/trainer/trainer.py |  3 +-
 tests/models/test_cpu.py             | 50 ++++++++++++++++++++++++++++
 4 files changed, 69 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c948e22e7b553..0fa590d0e1873 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,11 +29,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
 - `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
 - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
 - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
+- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
 
 
 ## [0.7.1] - 2020-03-07
diff --git a/docs/source/fast_training.rst b/docs/source/fast_training.rst
index 4500ebde88dc6..b741107ca17b4 100644
--- a/docs/source/fast_training.rst
+++ b/docs/source/fast_training.rst
@@ -1,10 +1,10 @@
 Fast Training
-================
+=============
 There are multiple options to speed up different parts of the training by choosing to train on a subset of data.
 This could be done for speed or debugging purposes.
 
 Check validation every n epochs
--------------------------------------
+-------------------------------
 If you have a small dataset you might want to check validation every n epochs
 
 .. code-block:: python
@@ -13,7 +13,7 @@ If you have a small dataset you might want to check validation every n epochs
     trainer = Trainer(check_val_every_n_epoch=1)
 
 Force training for min or max epochs
--------------------------------------
+------------------------------------
 It can be useful to force training for a minimum number of epochs or limit to a max number.
 
 .. seealso::
@@ -26,7 +26,7 @@ It can be useful to force training for a minimum number of epochs or limit to a
 
 
 Set validation check frequency within 1 training epoch
--------------------------------------------------------
+------------------------------------------------------
 For large datasets it's often desirable to check validation multiple times within a training loop.
 Pass in a float to check that often within 1 training epoch. Pass in an int k to check every k training batches.
 Must use an int if using an IterableDataset.
@@ -43,7 +43,7 @@ Must use an int if using an IterableDataset.
     trainer = Trainer(val_check_interval=100)
 
 Use training data subset
------------------------------------
+------------------------
 If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
 
 .. code-block:: python
@@ -54,12 +54,11 @@ If you don't want to check 100% of the training set (for debugging or if it's hu
     # check 10% only
     trainer = Trainer(train_percent_check=0.1)
 
-.. note:: train_percent_check will be overwritten by overfit_pct if overfit_pct > 0
+.. note:: ``train_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.
 
 Use test data subset
--------------------------------------
-If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag
-test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
+--------------------
+If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag.
 
 .. code-block:: python
 
@@ -69,10 +68,11 @@ test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
     # check 10% only
     trainer = Trainer(test_percent_check=0.1)
 
+.. note:: ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.
+
 Use validation data subset
---------------------------------------------
-If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag
-val_percent_check will be overwritten by overfit_pct if overfit_pct > 0
+--------------------------
+If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag.
 
 .. code-block:: python
 
@@ -80,4 +80,7 @@ val_percent_check will be overwritten by overfit_pct if overfit_pct > 0
     trainer = Trainer(val_percent_check=1.0)
 
     # check 10% only
-    trainer = Trainer(val_percent_check=0.1)
\ No newline at end of file
+    trainer = Trainer(val_percent_check=0.1)
+
+.. note:: ``val_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0 and ignored if
+    ``fast_dev_run=True``.
\ No newline at end of file
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 881a2e9103301..4b8a42aba90a2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -876,7 +876,8 @@ def run_pretrain_routine(self, model: LightningModule):
             return
 
         # check if we should run validation during training
-        self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
+        self.disable_validation = not (self.is_overriden('validation_step') and self.val_percent_check > 0) \
+            and not self.fast_dev_run
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index f5f5095d33a6d..4fd7b3839d907 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -14,6 +14,7 @@
     LightTrainDataloader,
     LightningTestModel,
     LightTestMixin,
+    LightValidationMixin
 )
 
 
@@ -156,6 +157,55 @@ class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
     tutils.assert_ok_model_acc(trainer)
 
 
+def test_disabled_validation():
+    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
+    tutils.reset_seed()
+
+    class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
+
+        validation_step_invoked = False
+        validation_end_invoked = False
+
+        def validation_step(self, *args, **kwargs):
+            self.validation_step_invoked = True
+            return super().validation_step(*args, **kwargs)
+
+        def validation_end(self, *args, **kwargs):
+            self.validation_end_invoked = True
+            return super().validation_end(*args, **kwargs)
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentModel(hparams)
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.4,
+        val_percent_check=0.0,
+        fast_dev_run=False,
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # check that val_percent_check=0 turns off validation
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 1
+    assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
+    assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
+
+    # check that val_percent_check has no influence when fast_dev_run is turned on
+    model = CurrentModel(hparams)
+    trainer_options.update(fast_dev_run=True)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 0
+    assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
+    assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
+
+
 def test_single_gpu_batch_parse():
     tutils.reset_seed()
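Usage sketch (not part of the patch): the snippet below shows how the fixed behaviour looks from the user's side, assuming the 0.7.x-era `Trainer` flags used above (`val_percent_check`, `fast_dev_run`, `max_epochs`). `MyModel` is a hypothetical `LightningModule` that overrides `validation_step`, standing in for any real model.

    from pytorch_lightning import Trainer

    # MyModel is a hypothetical LightningModule that overrides validation_step.
    model = MyModel()

    # With this fix, val_percent_check=0 disables the validation loop entirely,
    # even though validation_step is overridden.
    trainer = Trainer(max_epochs=2, val_percent_check=0.0)
    trainer.fit(model)

    # fast_dev_run=True takes precedence: a single-batch train/val sanity run
    # is still performed, so val_percent_check=0 is ignored here.
    trainer = Trainer(fast_dev_run=True, val_percent_check=0.0)
    trainer.fit(model)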