Skip to content

Commit

Permalink
Restore trainer.current_epoch after tuning (#7434)
Browse files Browse the repository at this point in the history
* Add a test

* Save and restore current_epoch

* Update CHANGELOG

* alphabetical order
  • Loading branch information
akihironitta authored May 8, 2021
1 parent 45143fd commit 710b144
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))


- Fixed `Trainer.current_epoch` not getting restored after tuning ([#7434](https://github.com/PyTorchLightning/pytorch-lightning/pull/7434))


## [1.3.0] - 2021-05-06

### Added
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __lr_finder_dump_params(trainer, model):
'logger': trainer.logger,
'max_steps': trainer.max_steps,
'checkpoint_callback': trainer.checkpoint_callback,
'current_epoch': trainer.current_epoch,
'configure_optimizers': model.configure_optimizers,
}

Expand All @@ -297,6 +298,7 @@ def __lr_finder_restore_params(trainer, model):
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.max_steps = trainer.__dumped_params['max_steps']
trainer.current_epoch = trainer.__dumped_params['current_epoch']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
del trainer.__dumped_params

Expand Down
8 changes: 7 additions & 1 deletion tests/tuner/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def test_trainer_reset_correctly(tmpdir):
)

changed_attributes = [
'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback'
'accumulate_grad_batches',
'auto_lr_find',
'callbacks',
'checkpoint_callback',
'current_epoch',
'logger',
'max_steps',
]
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
trainer.tuner.lr_find(model, num_training=5)
Expand Down
8 changes: 4 additions & 4 deletions tests/tuner/test_scale_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def test_trainer_reset_correctly(tmpdir):
)

changed_attributes = [
'max_steps',
'weights_summary',
'logger',
'callbacks',
'checkpoint_callback',
'limit_train_batches',
'current_epoch',
'limit_train_batches',
'logger',
'max_steps',
'weights_summary',
]
expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
trainer.tuner.scale_batch_size(model, max_trials=5)
Expand Down

0 comments on commit 710b144

Please sign in to comment.