Disable training when limit_train_batches=0 (#4371)
* Disable training when limit_train_batches=0

* chlog

* pep

* limit_train_batches

* BoringModel

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
rohitgr7 and s-rog authored Nov 3, 2020
1 parent ad2556b commit 360b3d8
Showing 5 changed files with 80 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -41,10 +41,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
 
+- Disabled training when `limit_train_batches=0` ([#4371](https://github.com/PyTorchLightning/pytorch-lightning/pull/4371))
+
 - Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))
 
 - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 
 ## [1.0.4] - 2020-10-27
 
 ### Added
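To make the behavioural change concrete, here is a minimal sketch of how a user hits the new code path. The toy module is a hypothetical stand-in for the `BoringModel` used in the tests below, and the `Trainer` arguments are assumed from the 1.0.x API:

```python
# Minimal sketch, assuming the PyTorch Lightning 1.0.x API.
# With limit_train_batches=0, fit() now returns without entering the
# epoch loop instead of iterating over zero batches per epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):  # hypothetical stand-in for BoringModel
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()  # scalar loss

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(limit_train_batches=0, max_epochs=1)
trainer.fit(ToyModel())  # training loop is skipped; weights stay untouched
```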
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -482,6 +482,10 @@ def train(self):
         # hook
         self.train_loop.on_train_start()
 
+        if self.train_loop.should_skip_training():
+            self.train_loop.on_train_end()
+            return
+
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
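Note the placement of the new check: `on_train_start` has already fired, so `on_train_end` is invoked explicitly before returning, keeping the hook sequence balanced for loggers and callbacks. A condensed sketch of the resulting control flow (simplified; the real method wraps the epoch loop in exception handling):

```python
# Condensed, simplified sketch of Trainer.train() after this change;
# signal handling and teardown details are omitted.
def train(self):
    self.train_loop.on_train_start()          # setup hooks have fired

    if self.train_loop.should_skip_training():
        self.train_loop.on_train_end()        # keep hooks balanced
        return                                # no epoch is ever run

    for epoch in range(self.current_epoch, self.max_epochs):
        self.train_loop.run_training_epoch()  # normal training path

    self.train_loop.on_train_end()
```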
11 changes: 10 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -77,6 +77,15 @@ def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
         return num_optimizers
 
+    def should_skip_training(self):
+        if self.trainer.current_epoch >= self.trainer.max_epochs:
+            return True
+
+        if self.trainer.limit_train_batches == 0:
+            return True
+
+        return False
+
     def on_train_start(self):
         # clear cache before training
         if self.trainer.on_gpu and self.trainer.root_gpu is not None:
@@ -597,7 +606,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1
 
             # stop epoch if we limited the number of training batches
-            if batch_idx + 1 >= self.trainer.num_training_batches:
+            if (batch_idx + 1) >= self.trainer.num_training_batches:
                 break
 
             # progress global step according to grads progress
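Two details of `should_skip_training` are easy to miss: the epoch guard also covers resuming from a checkpoint that already reached `max_epochs`, and the `== 0` comparison matches both the int and float spellings of the limit (the test below passes `0.0`). A standalone restatement of the predicate, assuming only the three trainer attributes referenced above:

```python
# Standalone restatement of the skip predicate from the diff above.
def should_skip_training(current_epoch, max_epochs, limit_train_batches):
    # Skip when resuming a run that already finished all its epochs ...
    if current_epoch >= max_epochs:
        return True
    # ... or when the user disabled training outright. Since 0 == 0.0
    # in Python, limit_train_batches=0.0 disables training as well.
    if limit_train_batches == 0:
        return True
    return False


assert should_skip_training(0, 2, 0.0)       # float zero skips too
assert should_skip_training(5, 5, 1.0)       # already-finished run
assert not should_skip_training(0, 2, 0.2)   # ordinary fractional limit
```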
2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -231,7 +231,7 @@ def on_validation_epoch_end(self, trainer, pl_module):
         default_root_dir=tmpdir,
         max_epochs=1,
         num_sanity_val_steps=2,
-        limit_train_batches=0,
+        limit_train_batches=1,
         limit_val_batches=limit_val_batches,
         callbacks=[progress_bar],
         logger=False,
62 changes: 62 additions & 0 deletions tests/trainer/test_trainer.py
@@ -747,6 +747,68 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
     assert trainer.tested_ckpt_path == ckpt_path
 
 
+def test_disabled_training(tmpdir):
+    """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
+
+    class CurrentModel(BoringModel):
+
+        training_step_invoked = False
+        training_epoch_end_invoked = False
+
+        def training_step(self, *args, **kwargs):
+            self.training_step_invoked = True
+            return super().training_step(*args, **kwargs)
+
+        def training_epoch_end(self, *args, **kwargs):
+            self.training_epoch_end_invoked = True
+            return super().training_epoch_end(*args, **kwargs)
+
+    model = CurrentModel()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=2,
+        limit_train_batches=0.0,
+        limit_val_batches=0.2,
+        fast_dev_run=False,
+    )
+
+    before_state_dict = deepcopy(model.state_dict())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    after_state_dict = model.state_dict()
+
+    for key in before_state_dict.keys():
+        assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key]))
+
+    # check that limit_train_batches=0 turns off training
+    assert result == 1, "training failed to complete"
+    assert trainer.current_epoch == 0
+    assert not model.training_step_invoked, "`training_step` should not run when `limit_train_batches=0`"
+    assert not model.training_epoch_end_invoked, "`training_epoch_end` should not run when `limit_train_batches=0`"
+
+    # check that limit_train_batches has no influence when fast_dev_run is turned on
+    model = CurrentModel()
+    trainer_options.update(fast_dev_run=True)
+    before_state_dict = deepcopy(model.state_dict())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    after_state_dict = model.state_dict()
+
+    for key in before_state_dict.keys():
+        assert not torch.all(torch.eq(before_state_dict[key], after_state_dict[key]))
+
+    assert result == 1, "training failed to complete"
+    assert trainer.current_epoch == 0
+    assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`"
+    assert model.training_epoch_end_invoked, "did not run `training_epoch_end` with `fast_dev_run=True`"
+
+
 def test_disabled_validation(tmpdir):
     """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""
 
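As the second half of the test asserts, `fast_dev_run=True` takes precedence over the batch limits: in this release it forces a single training (and validation) batch so the run can be smoke-tested. A hedged usage sketch, reusing the hypothetical `ToyModel` from the first example:

```python
import pytorch_lightning as pl

from toy_module import ToyModel  # hypothetical import, see earlier sketch

# Assumed precedence, per the assertions above: fast_dev_run=True
# overrides limit_train_batches, so one training batch still runs.
trainer = pl.Trainer(fast_dev_run=True, limit_train_batches=0.0)
trainer.fit(ToyModel())  # training_step runs exactly once
```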
