diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc430356191c3..8d9c3d8a1f186 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,10 +41,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
 
+- Disabled training when `limit_train_batches=0` ([#4371](https://github.com/PyTorchLightning/pytorch-lightning/pull/4371))
+
 - Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))
 
 - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
+
 ## [1.0.4] - 2020-10-27
 
 ### Added
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d3cc2f2e7278f..49cf232f76ac7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -482,6 +482,10 @@ def train(self):
         # hook
         self.train_loop.on_train_start()
 
+        if self.train_loop.should_skip_training():
+            self.train_loop.on_train_end()
+            return
+
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 0a931257f560f..3845b7eb728ac 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -77,6 +77,15 @@ def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
         return num_optimizers
 
+    def should_skip_training(self):
+        if self.trainer.current_epoch >= self.trainer.max_epochs:
+            return True
+
+        if self.trainer.limit_train_batches == 0:
+            return True
+
+        return False
+
     def on_train_start(self):
         # clear cache before training
         if self.trainer.on_gpu and self.trainer.root_gpu is not None:
@@ -597,7 +606,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1
 
             # stop epoch if we limited the number of training batches
-            if batch_idx + 1 >= self.trainer.num_training_batches:
+            if (batch_idx + 1) >= self.trainer.num_training_batches:
                 break
 
             # progress global step according to grads progress
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index d354b59682240..221844244ad75 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -231,7 +231,7 @@ def on_validation_epoch_end(self, trainer, pl_module):
         default_root_dir=tmpdir,
         max_epochs=1,
         num_sanity_val_steps=2,
-        limit_train_batches=0,
+        limit_train_batches=1,
         limit_val_batches=limit_val_batches,
         callbacks=[progress_bar],
         logger=False,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d45c1c50cc060..51fa89ee5539b 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -747,6 +747,68 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
     assert trainer.tested_ckpt_path == ckpt_path
 
 
+def test_disabled_training(tmpdir):
+    """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
+
+    class CurrentModel(BoringModel):
+
+        training_step_invoked = False
+        training_epoch_end_invoked = False
+
+        def training_step(self, *args, **kwargs):
+            self.training_step_invoked = True
+            return super().training_step(*args, **kwargs)
+
+        def training_epoch_end(self, *args, **kwargs):
+            self.training_epoch_end_invoked = True
+            return super().training_epoch_end(*args, **kwargs)
+
+    model = CurrentModel()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=2,
+        limit_train_batches=0.0,
+        limit_val_batches=0.2,
+        fast_dev_run=False,
+    )
+
+    before_state_dict = deepcopy(model.state_dict())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    after_state_dict = model.state_dict()
+
+    for key in before_state_dict.keys():
+        assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key]))
+
+    # check that limit_train_batches=0 turns off training
+    assert result == 1, "training failed to complete"
+    assert trainer.current_epoch == 0
+    assert not model.training_step_invoked, "`training_step` should not run when `limit_train_batches=0`"
+    assert not model.training_epoch_end_invoked, "`training_epoch_end` should not run when `limit_train_batches=0`"
+
+    # check that limit_train_batches has no influence when fast_dev_run is turned on
+    model = CurrentModel()
+    trainer_options.update(fast_dev_run=True)
+    before_state_dict = deepcopy(model.state_dict())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    after_state_dict = model.state_dict()
+
+    for key in before_state_dict.keys():
+        assert not torch.all(torch.eq(before_state_dict[key], after_state_dict[key]))
+
+    assert result == 1, "training failed to complete"
+    assert trainer.current_epoch == 0
+    assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`"
+    assert model.training_epoch_end_invoked, "did not run `training_epoch_end` with `fast_dev_run=True`"
+
+
 def test_disabled_validation(tmpdir):
     """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""
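
A minimal usage sketch of the behavior this patch introduces (not part of the diff itself), assuming the `BoringModel` helper used in the test above and the 1.0-era `Trainer` API: with `limit_train_batches=0`, `fit` returns through the new `should_skip_training()` short-circuit before any training batch runs.

    from pytorch_lightning import Trainer

    # `BoringModel` is assumed importable from the Lightning test helpers,
    # as in tests/trainer/test_trainer.py above.
    model = BoringModel()
    trainer = Trainer(limit_train_batches=0, max_epochs=2, progress_bar_refresh_rate=0)
    trainer.fit(model)                 # exits early; `training_step` is never invoked
    assert trainer.current_epoch == 0  # no epoch was started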