From 262900d57859e7028806e30a4317c216c008c7d8 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 16 Sep 2021 13:53:15 +0100
Subject: [PATCH 01/12] move reset_on_restart in the loop

---
 pytorch_lightning/loops/base.py                  | 15 +++-----
 .../loops/epoch/evaluation_epoch_loop.py         |  2 ++
 .../loops/epoch/training_epoch_loop.py           |  5 +++
 pytorch_lightning/loops/fit_loop.py              |  2 ++
 .../loops/optimization/optimizer_loop.py         |  2 ++
 pytorch_lightning/trainer/progress.py            |  8 +++++
 tests/loops/test_loops.py                        | 34 +++++++++++++------
 7 files changed, 46 insertions(+), 22 deletions(-)

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 5573b04952ddd..1a19c753b0e2b 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -20,8 +20,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import BaseProgress, Progress
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 T = TypeVar("T")  # the output type of `run`
@@ -200,25 +199,19 @@ def load_state_dict(
         self,
         state_dict: Dict,
         prefix: str = "",
-        restart_progress: bool = True,
         metrics: Optional[Dict[str, Metric]] = None,
     ) -> None:
         """Loads the state of this loop and all its children."""
-        self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
+        self._load_from_state_dict(state_dict.copy(), prefix, metrics)
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
-                v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
+                v.load_state_dict(state_dict.copy(), prefix + k + ".")
 
-    def _load_from_state_dict(
-        self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
-    ) -> None:
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
-                if restart_progress:
-                    apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
-
             elif (
                 isinstance(v, ResultCollection)
                 and self.trainer is not None

diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index f30df960c1ad8..3c9cf0d7172f2 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -59,6 +59,8 @@ def reset(self) -> None:
 
         if not self.restarting:
             self.batch_progress.current.reset()
+        else:
+            self.batch_progress.current.reset_on_restart()
 
     def on_run_start(
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 274074653d1e7..bf51e05390619 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -95,6 +95,11 @@ def reset(self) -> None:
         assert self.batch_loop.optimizer_loop is not None
         self.is_last_batch = False
 
+        if self.restarting:
+            self.batch_progress.current.reset_on_restart()
+            self.scheduler_progress.current.reset_on_restart()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
+
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 3e9917a551193..9a4f7c510f303 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -174,6 +174,8 @@ def connect(self, epoch_loop: TrainingEpochLoop):
 
     def reset(self) -> None:
         """Resets the internal state of this loop."""
+        if self.restarting:
+            self.epoch_progress.current.reset_on_restart()
 
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""

diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index 484c3ba4bd60f..a5a7ca746733b 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -197,6 +197,8 @@ def reset(self) -> None:
         if not self.restarting:
             # when reset() is called from outside (manually), we reset the loop progress
             self.optim_progress.optimizer_position = 0
+        else:
+            self.optim_progress.reset_on_restart()
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
 
     def on_run_start(  # type: ignore[override]

diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index 5b4f072305947..4febd7bde6953 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -229,3 +229,11 @@ def reset_on_epoch(self) -> None:
     def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
         self.optimizer_position = state_dict["optimizer_position"]
+
+    def reset_on_restart(self):
+        self.optimizer.step.current.reset_on_restart()
+        self.optimizer.zero_grad.current.reset_on_restart()
+
+    def reset_on_batch_end(self):
+        self.optimizer.step.current.reset()
+        self.optimizer.zero_grad.current.reset()

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 870c525561899..b45ac7b1639db 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -297,7 +297,7 @@ def val_dataloader(self):
     }
     assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected
 
-    trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint)
 
     # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
     nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@@ -327,7 +327,12 @@ def val_dataloader(self):
             "processed": total_val_batch,
             "completed": total_val_batch,
         },
-        "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
+        "current": {
+            "ready": stop_batch + 1,
+            "started": stop_batch + 1,
+            "processed": stop_batch,
+            "completed": stop_batch,
+        },
     }
     assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
 
@@ -496,7 +501,7 @@ def configure_optimizers_multiple(self):
     }
     assert checkpoint["loops"]["fit_loop"] == expected
 
-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
     state_dict = trainer.fit_loop.state_dict()
 
     # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@@ -505,7 +510,14 @@ def configure_optimizers_multiple(self):
     assert state_dict == checkpoint["loops"]["fit_loop"]
 
     # with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
-    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
+    trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
+
+    trainer.fit_loop.reset()
+    trainer.fit_loop.epoch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.reset()
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.reset()
+    trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()
 
     epoch_progress = trainer.fit_loop.epoch_progress
     assert epoch_progress.current.ready == stop_epoch
@@ -693,18 +705,18 @@ def test_fit_loop_reset(tmpdir):
 
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
 
-    def mid_epoch_reset_assertions():
+    def mid_epoch_reset_assertions(has_reset: bool = False):
         assert fit_loop.restarting
         assert fit_loop.epoch_progress.total.ready == 1
         assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-        assert fit_loop.epoch_progress.current.ready == 0
+        assert fit_loop.epoch_progress.current.ready == (not int(has_reset))
         assert fit_loop.epoch_progress.current.completed == 0
 
         assert epoch_loop.restarting
         assert epoch_loop.batch_progress.total.ready == 2
         assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-        assert epoch_loop.batch_progress.current.ready == 2
-        assert epoch_loop.batch_progress.current.completed == 2
+        assert epoch_loop.batch_progress.current.ready == (2 if has_reset else 1)
+        assert epoch_loop.batch_progress.current.completed == (2 if has_reset else 1)
 
     # resetting from a mid-epoch checkpoint should not change progress counters
     mid_epoch_reset_assertions()
@@ -712,7 +724,7 @@ def mid_epoch_reset_assertions():
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
-    mid_epoch_reset_assertions()
+    mid_epoch_reset_assertions(has_reset=True)
     assert optimizer_loop.optim_progress.optimizer_position == 1
 
     # reset state loaded from a checkpoint from the end of an epoch
@@ -728,14 +740,14 @@ def mid_epoch_reset_assertions():
     assert fit_loop.restarting
     assert fit_loop.epoch_progress.total.ready == 1
     assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
-    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.ready == 1
     assert fit_loop.epoch_progress.current.completed == 0
 
     assert epoch_loop.restarting
     assert epoch_loop.batch_progress.total.ready == 4
     assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
     assert epoch_loop.batch_progress.current.ready == 4
-    assert epoch_loop.batch_progress.current.completed == 4
+    assert epoch_loop.batch_progress.current.completed == 3
 
     assert optimizer_loop.optim_progress.optimizer_position == 1
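
The shape of the change, condensed into a self-contained sketch: `load_state_dict` no
longer rewinds progress counters itself, it only marks the loop as `restarting`, and the
rewind happens in each loop's `reset`. The `Tracker` and `ToyLoop` below are simplified
stand-ins for `pytorch_lightning.trainer.progress` and the `Loop` subclasses touched
above — an illustration of the pattern, not the library implementation:

    from dataclasses import dataclass

    @dataclass
    class Tracker:
        ready: int = 0
        started: int = 0
        processed: int = 0
        completed: int = 0

        def reset(self) -> None:
            self.ready = self.started = self.processed = self.completed = 0

        def reset_on_restart(self) -> None:
            # reconcile the eagerly incremented counters with the last fully
            # processed step so a resumed loop does not double-count it
            self.ready = self.started = self.completed = self.processed

    class ToyLoop:
        def __init__(self) -> None:
            self.current = Tracker()
            self.restarting = False

        def load_state_dict(self, state: dict) -> None:
            self.current = Tracker(**state)
            self.restarting = True  # restoring only restores, no rewind here

        def reset(self) -> None:
            # the rewind now lives here, gated on the `restarting` flag
            if not self.restarting:
                self.current.reset()
            else:
                self.current.reset_on_restart()

    loop = ToyLoop()
    # state saved from `on_train_batch_end`, before `completed` was incremented
    loop.load_state_dict({"ready": 2, "started": 2, "processed": 2, "completed": 1})
    loop.reset()
    assert loop.current.completed == 2
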
From 3668fedf1fcbedf09e49314f4a9494060df3a8f7 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 16 Sep 2021 14:24:52 +0100
Subject: [PATCH 02/12] update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 42a720e5c0d1a..2f6d3d4864483 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
   * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
   * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
+  * Move `reset_on_restart` within the loop reset function ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
 
 
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

From e9a46d0ef7c8445b044f5a71a1a710689321fef3 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 16 Sep 2021 14:38:07 +0100
Subject: [PATCH 03/12] update

---
 pytorch_lightning/loops/dataloader/dataloader_loop.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py
index 6b5fecd07e807..00a5ee32b933b 100644
--- a/pytorch_lightning/loops/dataloader/dataloader_loop.py
+++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -57,6 +57,8 @@ def reset(self) -> None:
         """Resets the internal state."""
         if not self.restarting:
             self.dataloader_progress.current.reset()
+        else:
+            self.dataloader_progress.current.reset_on_restart()
 
     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         self.dataloader_progress.increment_ready()
From c4b86abeef815a5abada77b4267020097eee1577 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 16 Sep 2021 16:56:27 +0200
Subject: [PATCH 04/12] Minor test changes

---
 tests/loops/test_loops.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index b45ac7b1639db..839751cbd1318 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -509,9 +509,8 @@ def configure_optimizers_multiple(self):
     checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
     assert state_dict == checkpoint["loops"]["fit_loop"]
 
-    # with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
     trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
-
+    # test resetting manually, we expect all `ready` counters to be reset to `completed`
     trainer.fit_loop.reset()
     trainer.fit_loop.epoch_loop.reset()
     trainer.fit_loop.epoch_loop.batch_loop.reset()
@@ -705,11 +704,11 @@ def test_fit_loop_reset(tmpdir):
 
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
 
-    def mid_epoch_reset_assertions(has_reset: bool = False):
+    def mid_epoch_reset_assertions(has_reset):
         assert fit_loop.restarting
         assert fit_loop.epoch_progress.total.ready == 1
         assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-        assert fit_loop.epoch_progress.current.ready == (not int(has_reset))
+        assert fit_loop.epoch_progress.current.ready == (not has_reset)
         assert fit_loop.epoch_progress.current.completed == 0
 
         assert epoch_loop.restarting
@@ -719,7 +718,7 @@ def mid_epoch_reset_assertions(has_reset: bool = False):
         assert epoch_loop.batch_progress.current.completed == (2 if has_reset else 1)
 
     # resetting from a mid-epoch checkpoint should not change progress counters
-    mid_epoch_reset_assertions()
+    mid_epoch_reset_assertions(has_reset=False)

From ad334f51cd0a8cb9f6ea415c35646a679e6209ca Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 16 Sep 2021 16:16:27 +0100
Subject: [PATCH 05/12] Update CHANGELOG.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f6d3d4864483..8d7d59940f052 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
   * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
   * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
-  * Move `reset_on_restart` within the loop reset function ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
+  * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
 
 
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
From c672030cb250759ee4428118840016faabaec792 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 16 Sep 2021 16:22:13 +0100
Subject: [PATCH 06/12] update

---
 pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++--
 pytorch_lightning/trainer/progress.py                | 4 ----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index bf51e05390619..3d7f36477c55e 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -93,13 +93,13 @@ def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
         assert self.batch_loop is not None
         assert self.batch_loop.optimizer_loop is not None
-        self.is_last_batch = False
-
         if self.restarting:
             self.batch_progress.current.reset_on_restart()
             self.scheduler_progress.current.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
 
+        self.is_last_batch = False
+
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index 4febd7bde6953..0b58f610a693d 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -233,7 +233,3 @@ def load_state_dict(self, state_dict: dict) -> None:
     def reset_on_restart(self):
         self.optimizer.step.current.reset_on_restart()
         self.optimizer.zero_grad.current.reset_on_restart()
-
-    def reset_on_batch_end(self):
-        self.optimizer.step.current.reset()
-        self.optimizer.zero_grad.current.reset()

From 2107f2a92dcdb51eeb2833f00dd0d66069dfdef5 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 17 Sep 2021 14:41:59 +0200
Subject: [PATCH 07/12] Remove duplicated block

---
 tests/loops/test_loops.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 839751cbd1318..f6be3b0566e32 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -319,23 +319,6 @@ def val_dataloader(self):
     }
     assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
 
-    trainer.fit_loop.load_state_dict(checkpoint)
-    expected = {
-        "total": {
-            "ready": total_val_batch + 1,
-            "started": total_val_batch + 1,
-            "processed": total_val_batch,
-            "completed": total_val_batch,
-        },
-        "current": {
-            "ready": stop_batch + 1,
-            "started": stop_batch + 1,
-            "processed": stop_batch,
-            "completed": stop_batch,
-        },
-    }
-    assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
-
 
 @RunIf(min_torch="1.7.0")
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})

From 9249a723612508f13a81cf43dd35a526f6c06d59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 17 Sep 2021 15:22:54 +0200
Subject: [PATCH 08/12] more meaningful test

---
 tests/loops/test_loops.py | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index f6be3b0566e32..fb5590b3c06e6 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -687,26 +687,23 @@ def test_fit_loop_reset(tmpdir):
 
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
 
-    def mid_epoch_reset_assertions(has_reset):
-        assert fit_loop.restarting
-        assert fit_loop.epoch_progress.total.ready == 1
-        assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-        assert fit_loop.epoch_progress.current.ready == (not has_reset)
-        assert fit_loop.epoch_progress.current.completed == 0
-
-        assert epoch_loop.restarting
-        assert epoch_loop.batch_progress.total.ready == 2
-        assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-        assert epoch_loop.batch_progress.current.ready == (2 if has_reset else 1)
-        assert epoch_loop.batch_progress.current.completed == (2 if has_reset else 1)
-
-    # resetting from a mid-epoch checkpoint should not change progress counters
-    mid_epoch_reset_assertions(has_reset=False)
     assert optimizer_loop.optim_progress.optimizer_position == 1
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
-    mid_epoch_reset_assertions(has_reset=True)
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 2
+    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 2
+    assert epoch_loop.batch_progress.current.completed == 2
+
+    assert optimizer_loop.restarting
     assert optimizer_loop.optim_progress.optimizer_position == 1
 
     # reset state loaded from a checkpoint from the end of an epoch
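
A worked trace of the counters the rewritten test pins down, under the simplified
assumption that a restart reconciles every field of `current` to its `processed`
value (an approximation of what `Tracker.reset_on_restart` does, not a verbatim copy):

    # saved on `on_train_batch_end` of the second batch, before `completed` ticked
    saved = {"ready": 2, "started": 2, "processed": 2, "completed": 1}

    def reset_on_restart(tracker: dict) -> dict:
        # every counter is reconciled to the last fully processed step
        return {key: tracker["processed"] for key in tracker}

    after_load = dict(saved)  # `load_state_dict` restores verbatim
    after_reset = reset_on_restart(after_load)
    assert after_reset == {"ready": 2, "started": 2, "processed": 2, "completed": 2}

That is, `current.completed` moves from 1 to 2 across the explicit `reset()` calls,
which is exactly what the new inline assertions check.
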
From d3f0293f69baa499f3c9f0cd8f5fae406ea9a939 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 17 Sep 2021 15:34:46 +0200
Subject: [PATCH 09/12] Remove duplicated line

---
 tests/loops/test_loops.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index fb5590b3c06e6..9ce04849244cd 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -686,8 +686,6 @@ def test_fit_loop_reset(tmpdir):
     assert not optimizer_loop.restarting
 
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
-
-    assert optimizer_loop.optim_progress.optimizer_position == 1
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()

From 3254d305194c54d00c0a35d3dd3d3e85693e8bc2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 17 Sep 2021 15:48:23 +0200
Subject: [PATCH 10/12] Add back value checks before `reset`

---
 tests/loops/test_loops.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 9ce04849244cd..cfb8fe0be1932 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -685,10 +685,29 @@ def test_fit_loop_reset(tmpdir):
     assert not epoch_loop.restarting
     assert not optimizer_loop.restarting
 
+    # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+    assert fit_loop.epoch_progress.current.ready == 1
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 2
+    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 2
+    assert epoch_loop.batch_progress.current.completed == 1
+
+    assert optimizer_loop.restarting
+    assert optimizer_loop.optim_progress.optimizer_position == 1
+
+    # resetting from a mid-epoch checkpoint SHOULD NOT reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
+
     assert fit_loop.restarting
     assert fit_loop.epoch_progress.total.ready == 1
     assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
@@ -712,7 +731,8 @@ def test_fit_loop_reset(tmpdir):
 
     epoch_loop.restarting = False
     optimizer_loop.restarting = False
 
+    # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
 
     assert fit_loop.restarting
@@ -728,7 +748,7 @@ def test_fit_loop_reset(tmpdir):
 
     assert optimizer_loop.optim_progress.optimizer_position == 1
 
-    # resetting from a end-of-epoch checkpoint should reset the current counters to 0
+    # resetting from an end-of-epoch checkpoint SHOULD reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
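
The two halves of the test now state the contract explicitly: a mid-epoch checkpoint
resumes the epoch in flight, while an end-of-epoch checkpoint starts the next epoch from
scratch. In miniature (illustrative only — `epoch_finished` below is a hypothetical
stand-in for the loop's own done-check, not a real attribute):

    def reset_current(tracker: dict, epoch_finished: bool) -> dict:
        if epoch_finished:
            # the saved epoch already ran to its end: the restart begins a
            # fresh epoch, so the per-epoch counters go back to zero
            return {key: 0 for key in tracker}
        # mid-epoch: reconcile to the last processed step and continue
        return {key: tracker["processed"] for key in tracker}

    mid_epoch = {"ready": 2, "started": 2, "processed": 2, "completed": 1}
    end_of_epoch = {"ready": 4, "started": 4, "processed": 4, "completed": 3}
    assert reset_current(mid_epoch, epoch_finished=False)["ready"] == 2
    assert reset_current(end_of_epoch, epoch_finished=True)["ready"] == 0
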
From ff50b6ecdd4b22d682c010f7dfe93a27ddd61941 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 17 Sep 2021 16:01:04 +0200
Subject: [PATCH 11/12] Remove code at the request of Adrian

---
 tests/loops/test_loops.py | 31 -------------------------------
 1 file changed, 31 deletions(-)

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index cfb8fe0be1932..47145a2f8f408 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -687,22 +687,6 @@ def test_fit_loop_reset(tmpdir):
 
     # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
-
-    assert fit_loop.restarting
-    assert fit_loop.epoch_progress.total.ready == 1
-    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-    assert fit_loop.epoch_progress.current.ready == 1
-    assert fit_loop.epoch_progress.current.completed == 0
-
-    assert epoch_loop.restarting
-    assert epoch_loop.batch_progress.total.ready == 2
-    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-    assert epoch_loop.batch_progress.current.ready == 2
-    assert epoch_loop.batch_progress.current.completed == 1
-
-    assert optimizer_loop.restarting
-    assert optimizer_loop.optim_progress.optimizer_position == 1
-
     # resetting from a mid-epoch checkpoint SHOULD NOT reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
@@ -733,21 +717,6 @@ def test_fit_loop_reset(tmpdir):
 
     # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
-
-    assert fit_loop.restarting
-    assert fit_loop.epoch_progress.total.ready == 1
-    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
-    assert fit_loop.epoch_progress.current.ready == 1
-    assert fit_loop.epoch_progress.current.completed == 0
-
-    assert epoch_loop.restarting
-    assert epoch_loop.batch_progress.total.ready == 4
-    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
-    assert epoch_loop.batch_progress.current.ready == 4
-    assert epoch_loop.batch_progress.current.completed == 3
-
-    assert optimizer_loop.optim_progress.optimizer_position == 1
-
     # resetting from an end-of-epoch checkpoint SHOULD reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()

From 76fce3dcaed116a385e753105093409b9817fa21 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 17 Sep 2021 16:31:25 +0100
Subject: [PATCH 12/12] update

---
 pytorch_lightning/trainer/progress.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index 0b58f610a693d..0f07c61999e1c 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -153,6 +153,9 @@ def load_state_dict(self, state_dict: dict) -> None:
         self.total.load_state_dict(state_dict["total"])
         self.current.load_state_dict(state_dict["current"])
 
+    def reset_on_restart(self) -> None:
+        self.current.reset_on_restart()
+
 
 @dataclass
 class DataLoaderProgress(Progress):
@@ -230,6 +233,6 @@ def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
         self.optimizer_position = state_dict["optimizer_position"]
 
-    def reset_on_restart(self):
+    def reset_on_restart(self) -> None:
        self.optimizer.step.current.reset_on_restart()
        self.optimizer.zero_grad.current.reset_on_restart()
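
Where the series lands: `reset_on_restart` is a small protocol threaded through the
progress dataclasses — `Progress` forwards to its `current` tracker, and
`OptimizationProgress` forwards to the trackers of its `step` and `zero_grad`
sub-progress. A condensed sketch of that final shape (field layout simplified from
`pytorch_lightning/trainer/progress.py`; the two-field `Tracker` is an abbreviation,
and for a tracker without a `processed` field, `completed` is the reference value):

    from dataclasses import dataclass, field

    @dataclass
    class Tracker:
        ready: int = 0
        completed: int = 0

        def reset_on_restart(self) -> None:
            # rewind the eagerly incremented counter to the completed count
            self.ready = self.completed

    @dataclass
    class Progress:
        total: Tracker = field(default_factory=Tracker)
        current: Tracker = field(default_factory=Tracker)

        def reset_on_restart(self) -> None:
            self.current.reset_on_restart()  # totals are deliberately kept

    @dataclass
    class OptimizerProgress:
        step: Progress = field(default_factory=Progress)
        zero_grad: Progress = field(default_factory=Progress)

    @dataclass
    class OptimizationProgress:
        optimizer: OptimizerProgress = field(default_factory=OptimizerProgress)
        optimizer_position: int = 0

        def reset_on_restart(self) -> None:
            self.optimizer.step.current.reset_on_restart()
            self.optimizer.zero_grad.current.reset_on_restart()

    progress = OptimizationProgress()
    progress.optimizer.step.current.ready = 3
    progress.optimizer.step.current.completed = 2
    progress.reset_on_restart()
    assert progress.optimizer.step.current.ready == 2
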