Skip to content

Commit

Permalink
Restore log step during restart (#13467)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
rohitgr7 and carmocca authored Jul 12, 2022
1 parent 24189c2 commit df931e2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417))


- Fixed the log step not being restored when training is restarted from a checkpoint ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


## [1.6.4] - 2022-06-01

Expand Down
2 changes: 2 additions & 0 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def teardown(self) -> None:

def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
state_dict["_batches_that_stepped"] = self._batches_that_stepped

if (
self.trainer is not None
Expand All @@ -292,6 +293,7 @@ def on_save_checkpoint(self) -> Dict:
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
# restore the stepped-batch counter written by `on_save_checkpoint`; a state dict
# that lacks the "_batches_that_stepped" key falls back to 0 instead of raising
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

def _run_validation(self) -> None:
# reload dataloaders
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_loop_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_loops_state_dict_structure():
expected = {
"fit_loop": {
"state_dict": {},
"epoch_loop.state_dict": {},
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand Down
1 change: 1 addition & 0 deletions tests/tests_pytorch/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def on_train_start(self) -> None:
trainer.fit(TestModel(), ckpt_path=ckpt_path)
assert trainer.current_epoch == max_epochs
assert trainer.global_step == max_epochs * train_batches
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


def test_fit_twice(tmpdir):
Expand Down

0 comments on commit df931e2

Please sign in to comment.