Disable validation when val_percent_check=0 (Lightning-AI#1251)
* fix disable validation

* add test

* update changelog

* update docs for val_percent_check

* make "fast training" docs consistent
Adrian Wälchli authored and akarnachev committed Apr 3, 2020
1 parent 004d7ec commit 1db545a
Showing 4 changed files with 69 additions and 15 deletions.
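For context, a minimal usage sketch of the behavior this commit introduces (flag values are illustrative; the Trainer flags are those shown in the diffs below):

```python
from pytorch_lightning import Trainer

# val_percent_check=0 now disables the validation loop entirely
trainer = Trainer(val_percent_check=0.0)

# ...unless fast_dev_run=True, which still runs a single validation batch
# to verify the program won't crash
trainer = Trainer(val_percent_check=0.0, fast_dev_run=True)
```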
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -29,11 +29,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
- Fixed a bug to ensure lightning checkpoints are backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))

## [0.7.1] - 2020-03-07

29 changes: 16 additions & 13 deletions docs/source/fast_training.rst
@@ -1,10 +1,10 @@
Fast Training
================
=============
There are multiple options to speed up different parts of the training by choosing to train
on a subset of data. This could be done for speed or debugging purposes.

Check validation every n epochs
-------------------------------------
-------------------------------
If you have a small dataset you might want to check validation every n epochs

.. code-block:: python
@@ -13,7 +13,7 @@ If you have a small dataset you might want to check validation every n epochs
trainer = Trainer(check_val_every_n_epoch=1)
Force training for min or max epochs
-------------------------------------
------------------------------------
It can be useful to force training for a minimum number of epochs or limit to a max number.

.. seealso::
@@ -26,7 +26,7 @@ It can be useful to force training for a minimum number of epochs or limit to a max number.
Set validation check frequency within 1 training epoch
-------------------------------------------------------
------------------------------------------------------
For large datasets it's often desirable to check validation multiple times within a training loop.
Pass in a float to check that often within 1 training epoch. Pass in an int k to check every k training batches.
Must use an int if using an IterableDataset.
@@ -43,7 +43,7 @@ Must use an int if using an IterableDataset.
trainer = Trainer(val_check_interval=100)
Use training data subset
----------------------------------
------------------------
If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.

.. code-block:: python
@@ -54,12 +54,11 @@ If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
# check 10% only
trainer = Trainer(train_percent_check=0.1)
.. note:: train_percent_check will be overwritten by overfit_pct if overfit_pct > 0
.. note:: ``train_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.

Use test data subset
-------------------------------------
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag
test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
--------------------
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag.

.. code-block:: python
@@ -69,15 +68,19 @@ test_percent_check will be overwritten by overfit_pct if overfit_pct > 0.
# check 10% only
trainer = Trainer(test_percent_check=0.1)
.. note:: ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0.

Use validation data subset
--------------------------------------------
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag
val_percent_check will be overwritten by overfit_pct if overfit_pct > 0
--------------------------
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag.

.. code-block:: python
# DEFAULT
trainer = Trainer(val_percent_check=1.0)
# check 10% only
trainer = Trainer(val_percent_check=0.1)
trainer = Trainer(val_percent_check=0.1)
.. note:: ``val_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0 and ignored if
``fast_dev_run=True``.
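Taken together, the notes above give `overfit_pct` precedence over the individual `*_percent_check` flags. A small sketch of that precedence (values are illustrative):

```python
from pytorch_lightning import Trainer

# overfit_pct > 0 overrides train/val/test_percent_check,
# so all three loops run on the same 1% subset of the data
trainer = Trainer(
    overfit_pct=0.01,
    train_percent_check=0.5,  # overridden by overfit_pct
    val_percent_check=0.5,    # overridden by overfit_pct
    test_percent_check=0.5,   # overridden by overfit_pct
)
```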
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -876,7 +876,8 @@ def run_pretrain_routine(self, model: LightningModule):
return

# check if we should run validation during training
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
self.disable_validation = not (self.is_overriden('validation_step') and self.val_percent_check > 0) \
and not self.fast_dev_run

# run tiny validation (if validation defined)
# to make sure program won't crash during val
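Restated as a standalone predicate (a hypothetical helper, for illustration only; the actual logic lives inline above): validation is disabled when there is no `validation_step` or `val_percent_check` is 0, unless `fast_dev_run` is set.

```python
def should_disable_validation(has_validation_step: bool,
                              val_percent_check: float,
                              fast_dev_run: bool) -> bool:
    # Mirrors the updated condition: skip validation if the user did not
    # override `validation_step` or requested 0% of the val data, but
    # always keep it when fast_dev_run forces a single sanity batch.
    return not (has_validation_step and val_percent_check > 0) and not fast_dev_run
```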
50 changes: 50 additions & 0 deletions tests/models/test_cpu.py
@@ -14,6 +14,7 @@
LightTrainDataloader,
LightningTestModel,
LightTestMixin,
LightValidationMixin
)


@@ -156,6 +157,55 @@ class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
tutils.assert_ok_model_acc(trainer)


def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
tutils.reset_seed()

class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):

validation_step_invoked = False
validation_end_invoked = False

def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)

def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)

hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)

trainer_options = dict(
show_progress_bar=False,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,
fast_dev_run=False,
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)

# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'

# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
trainer_options.update(fast_dev_run=True)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)

assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'


def test_single_gpu_batch_parse():
tutils.reset_seed()

