[Refactor] 1/2 Move reset_on_restart within the loop reset (#9561)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
2 people authored and SeanNaren committed Sep 22, 2021
1 parent a0e3de9 commit 562e18f
Showing 9 changed files with 53 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
+ * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
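The thrust of the change: `load_state_dict` now only restores the saved counters and marks the loop as restarting; reconciling the `current` counters happens later, inside each loop's `reset` hook. A minimal sketch of the new control flow, where `SimpleLoop` and `Tracker` are illustrative stand-ins, not Lightning's real `Loop`/`Progress` hierarchy:

from dataclasses import dataclass


@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0

    def reset_on_restart(self) -> None:
        # work that was marked ready but never completed must be replayed
        self.ready = self.completed


class SimpleLoop:
    def __init__(self) -> None:
        self.current = Tracker()
        self.restarting = False

    def load_state_dict(self, state: dict) -> None:
        # after this commit: restore counters verbatim and flag the restart,
        # but do not touch progress here anymore
        self.current.ready = state["ready"]
        self.current.completed = state["completed"]
        self.restarting = True

    def reset(self) -> None:
        if not self.restarting:
            self.current.reset()  # fresh run: counters start from zero
        else:
            self.current.reset_on_restart()  # resume: roll back to last completed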
15 changes: 4 additions & 11 deletions pytorch_lightning/loops/base.py
@@ -20,8 +20,7 @@

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.progress import BaseProgress, Progress
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException

T = TypeVar("T") # the output type of `run`
@@ -200,25 +199,19 @@ def load_state_dict(
        self,
        state_dict: Dict,
        prefix: str = "",
-       restart_progress: bool = True,
        metrics: Optional[Dict[str, Metric]] = None,
    ) -> None:
        """Loads the state of this loop and all its children."""
-       self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
+       self._load_from_state_dict(state_dict.copy(), prefix, metrics)
        for k, v in self.__dict__.items():
            if isinstance(v, Loop):
-               v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
+               v.load_state_dict(state_dict.copy(), prefix + k + ".")

-   def _load_from_state_dict(
-       self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
-   ) -> None:
+   def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
        for k, v in self.__dict__.items():
            key = prefix + k
            if isinstance(v, BaseProgress):
                v.load_state_dict(state_dict[key])
-               if restart_progress:
-                   apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
-
            elif (
                isinstance(v, ResultCollection)
                and self.trainer is not None
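For reference, the recursion kept above loads any `BaseProgress` attribute directly and descends into child loops with a dotted prefix. A hypothetical flattened state dict (keys and values illustrative, not taken from a real checkpoint) shows how `prefix + k + "."` resolves:

# Hypothetical flattened loop state dict; keys and values are illustrative.
checkpoint = {
    "epoch_progress": {"total": {}, "current": {}},  # owned by the fit loop itself
    "epoch_loop.batch_progress": {"total": {}, "current": {}},  # one level down
    "epoch_loop.val_loop.dataloader_progress": {"total": {}, "current": {}},  # two levels down
}
# fit_loop.load_state_dict(checkpoint) consumes "epoch_progress" itself, then
# calls epoch_loop.load_state_dict(checkpoint, prefix="epoch_loop."), which
# recurses with prefix="epoch_loop.val_loop." for its own children.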
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -57,6 +57,8 @@ def reset(self) -> None:
"""Resets the internal state."""
if not self.restarting:
self.dataloader_progress.current.reset()
else:
self.dataloader_progress.current.reset_on_restart()

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()
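The two branches differ only in what happens to the `current` counters, and the same pattern repeats in the evaluation, training-epoch, fit, and optimizer loops below. A small sketch of the difference, assuming the tracker semantics from `pytorch_lightning/trainer/progress.py` at this commit (`reset_on_restart` rolls `ready` back to `completed`; `reset` zeroes everything):

from pytorch_lightning.trainer.progress import DataLoaderProgress

progress = DataLoaderProgress()
progress.increment_ready()      # dataloader 1 fetched...
progress.increment_completed()  # ...and finished
progress.increment_ready()      # dataloader 2 fetched, then the run stopped

assert (progress.current.ready, progress.current.completed) == (2, 1)

# restart branch: the unfinished dataloader is rolled back and will be replayed
progress.current.reset_on_restart()
assert (progress.current.ready, progress.current.completed) == (1, 1)

# fresh-run branch: everything in `current` starts over
progress.current.reset()
assert (progress.current.ready, progress.current.completed) == (0, 0)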
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -59,6 +59,8 @@ def reset(self) -> None:

        if not self.restarting:
            self.batch_progress.current.reset()
+       else:
+           self.batch_progress.current.reset_on_restart()

    def on_run_start(
        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
5 changes: 5 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -93,6 +93,11 @@ def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
assert self.batch_loop is not None
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.current.reset_on_restart()
self.scheduler_progress.current.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

self.is_last_batch = False

# track epoch output
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/fit_loop.py
@@ -174,6 +174,8 @@ def connect(self, epoch_loop: TrainingEpochLoop):

    def reset(self) -> None:
        """Resets the internal state of this loop."""
+       if self.restarting:
+           self.epoch_progress.current.reset_on_restart()

    def on_run_start(self) -> None:
        """Calls the ``on_train_start`` hook."""
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -197,6 +197,8 @@ def reset(self) -> None:
        if not self.restarting:
            # when reset() is called from outside (manually), we reset the loop progress
            self.optim_progress.optimizer_position = 0
+       else:
+           self.optim_progress.reset_on_restart()
        self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

    def on_run_start(  # type: ignore[override]
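Note the asymmetry in `reset` here: a manual reset rewinds `optimizer_position` to 0, while a restart keeps it so that a resumed run continues with the optimizer that was active at checkpoint time; only the step counters roll back (see `OptimizationProgress.reset_on_restart` in the `progress.py` hunk below). A hedged sketch of that intent, with simplified, illustrative fields:

# Illustrative only; the real fields live on `OptimizationProgress`.
class OptimProgressSketch:
    def __init__(self) -> None:
        self.optimizer_position = 0  # index of the optimizer to run next
        self.step_ready = 0
        self.step_completed = 0

    def reset_on_restart(self) -> None:
        # roll back in-flight step counters, but keep `optimizer_position`:
        # with multiple optimizers, resuming must not jump back to the first
        self.step_ready = self.step_completed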
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/progress.py
@@ -153,6 +153,9 @@ def load_state_dict(self, state_dict: dict) -> None:
        self.total.load_state_dict(state_dict["total"])
        self.current.load_state_dict(state_dict["current"])

+   def reset_on_restart(self) -> None:
+       self.current.reset_on_restart()
+

@dataclass
class DataLoaderProgress(Progress):

@@ -229,3 +232,7 @@ def reset_on_epoch(self) -> None:
    def load_state_dict(self, state_dict: dict) -> None:
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.optimizer_position = state_dict["optimizer_position"]
+
+   def reset_on_restart(self) -> None:
+       self.optimizer.step.current.reset_on_restart()
+       self.optimizer.zero_grad.current.reset_on_restart()
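Both new helpers delegate to `current.reset_on_restart()`. For the richer trackers that also record `started` and `processed`, the same rule applies to every field: collapse whatever was in flight back to `completed`. A sketch consistent with the tracker classes in this file, with names simplified:

from dataclasses import dataclass


@dataclass
class FullTrackerSketch:
    # simplified four-field tracker, like the ones used for batch progress
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # a batch may have been fetched (ready), begun (started), or even
        # processed without completing; all of that is redone after a restart
        self.ready = self.completed
        self.started = self.completed
        self.processed = self.completed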
78 changes: 28 additions & 50 deletions tests/loops/test_loops.py
@@ -297,7 +297,7 @@ def val_dataloader(self):
    }
    assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

-   trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
+   trainer.fit_loop.load_state_dict(checkpoint)

    # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
    nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@@ -319,18 +319,6 @@ def val_dataloader(self):
    }
    assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected

-   trainer.fit_loop.load_state_dict(checkpoint)
-   expected = {
-       "total": {
-           "ready": total_val_batch + 1,
-           "started": total_val_batch + 1,
-           "processed": total_val_batch,
-           "completed": total_val_batch,
-       },
-       "current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
-   }
-   assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
-

@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@@ -496,16 +484,22 @@ def configure_optimizers_multiple(self):
    }
    assert checkpoint["loops"]["fit_loop"] == expected

-   trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
+   trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
    state_dict = trainer.fit_loop.state_dict()

    # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
    # fit loop to have an iterator, which is only available during training
    checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
    assert state_dict == checkpoint["loops"]["fit_loop"]

-   # with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
-   trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
+   trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
+   # test resetting manually, we expect all `ready` counters to be reset to `completed`
+   trainer.fit_loop.reset()
+   trainer.fit_loop.epoch_loop.reset()
+   trainer.fit_loop.epoch_loop.batch_loop.reset()
+   trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
+   trainer.fit_loop.epoch_loop.val_loop.reset()
+   trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()

    epoch_progress = trainer.fit_loop.epoch_progress
    assert epoch_progress.current.ready == stop_epoch
@@ -691,28 +685,26 @@ def test_fit_loop_reset(tmpdir):
    assert not epoch_loop.restarting
    assert not optimizer_loop.restarting

+   # we load exactly what was saved - no reset yet
    fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])

+   def mid_epoch_reset_assertions():
+       assert fit_loop.restarting
+       assert fit_loop.epoch_progress.total.ready == 1
+       assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+       assert fit_loop.epoch_progress.current.ready == 0
+       assert fit_loop.epoch_progress.current.completed == 0
+
+       assert epoch_loop.restarting
+       assert epoch_loop.batch_progress.total.ready == 2
+       assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+       assert epoch_loop.batch_progress.current.ready == 2
+       assert epoch_loop.batch_progress.current.completed == 2
+
+   # resetting from a mid-epoch checkpoint should not change progress counters
+   mid_epoch_reset_assertions()
+   assert optimizer_loop.optim_progress.optimizer_position == 1

    # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
    fit_loop.reset()
    epoch_loop.reset()
    optimizer_loop.reset()
+   mid_epoch_reset_assertions()

-   assert fit_loop.restarting
-   assert fit_loop.epoch_progress.total.ready == 1
-   assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
-   assert fit_loop.epoch_progress.current.ready == 0
-   assert fit_loop.epoch_progress.current.completed == 0
-
-   assert epoch_loop.restarting
-   assert epoch_loop.batch_progress.total.ready == 2
-   assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-   assert epoch_loop.batch_progress.current.ready == 2
-   assert epoch_loop.batch_progress.current.completed == 2
-
-   assert optimizer_loop.restarting
    assert optimizer_loop.optim_progress.optimizer_position == 1

# reset state loaded from a checkpoint from the end of an epoch
@@ -723,23 +715,9 @@ def mid_epoch_reset_assertions():
    epoch_loop.restarting = False
    optimizer_loop.restarting = False

+   # we load exactly what was saved - no reset yet
    fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])

-   assert fit_loop.restarting
-   assert fit_loop.epoch_progress.total.ready == 1
-   assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
-   assert fit_loop.epoch_progress.current.ready == 0
-   assert fit_loop.epoch_progress.current.completed == 0
-
-   assert epoch_loop.restarting
-   assert epoch_loop.batch_progress.total.ready == 4
-   assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
-   assert epoch_loop.batch_progress.current.ready == 4
-   assert epoch_loop.batch_progress.current.completed == 4
-
-   assert optimizer_loop.optim_progress.optimizer_position == 1
-
-   # resetting from a end-of-epoch checkpoint should reset the current counters to 0
+   # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0
    fit_loop.reset()
    epoch_loop.reset()
    optimizer_loop.reset()
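The rewritten tests pin down the new contract: loading a checkpoint restores counters and sets `restarting`, but nothing is reconciled until `reset()` runs. Condensed from the tests above, with the checkpoint contents elided:

fit_loop = trainer.fit_loop
fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
assert fit_loop.restarting             # set by loading; counters untouched

fit_loop.reset()                       # only now do the loops reconcile their
fit_loop.epoch_loop.reset()            # `current` counters via reset_on_restart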
