
[Refactor] Move reset_on_restart within the loop reset #9561

Merged on Sep 17, 2021 (14 commits)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
15 changes: 4 additions & 11 deletions pytorch_lightning/loops/base.py
@@ -20,8 +20,7 @@

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException

T = TypeVar("T") # the output type of `run`
@@ -200,25 +199,19 @@ def load_state_dict(
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
self._load_from_state_dict(state_dict.copy(), prefix, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
v.load_state_dict(state_dict.copy(), prefix + k + ".")

def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())

elif (
isinstance(v, ResultCollection)
and self.trainer is not None
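To see the shape of the refactor in isolation: `load_state_dict` now only restores the raw counters and flags the loop as restarting, while the rollback of in-flight `current` counters happens later, in `reset`. Below is a minimal, self-contained sketch of that pattern; `Tracker` and `ToyLoop` are illustrative stand-ins, not Lightning's actual `Progress`/`Loop` API.

```python
from dataclasses import dataclass, field


@dataclass
class Tracker:
    """Simplified stand-in for Lightning's progress trackers."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset(self) -> None:
        # fresh run: wipe everything
        self.ready = self.started = self.processed = self.completed = 0

    def reset_on_restart(self) -> None:
        # restart: roll the in-flight counters back to the last fully completed step
        self.ready = self.started = self.processed = self.completed


@dataclass
class ToyLoop:
    current: Tracker = field(default_factory=Tracker)
    restarting: bool = False

    def load_state_dict(self, state_dict: dict) -> None:
        # after this PR: only restore the raw counters here and flag the restart ...
        self.current = Tracker(**state_dict)
        self.restarting = True

    def reset(self) -> None:
        # ... and resolve the restart when the loop actually (re)runs
        if self.restarting:
            self.current.reset_on_restart()
        else:
            self.current.reset()


loop = ToyLoop()
# pretend the checkpoint was written after batch 3 was dispatched but before it completed
loop.load_state_dict({"ready": 3, "started": 3, "processed": 2, "completed": 2})
loop.reset()
assert loop.current.ready == loop.current.completed == 2  # batch 3 will be re-run
```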
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -57,6 +57,8 @@ def reset(self) -> None:
"""Resets the internal state."""
if not self.restarting:
self.dataloader_progress.current.reset()
else:
self.dataloader_progress.current.reset_on_restart()

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -59,6 +59,8 @@ def reset(self) -> None:

if not self.restarting:
self.batch_progress.current.reset()
else:
self.batch_progress.current.reset_on_restart()

def on_run_start(
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
5 changes: 5 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -93,6 +93,11 @@ def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
assert self.batch_loop is not None
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.current.reset_on_restart()
self.scheduler_progress.current.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

self.is_last_batch = False

# track epoch output
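From the user's side, the code path these `reset` changes exercise is resuming `Trainer.fit` from a mid-epoch checkpoint. A minimal sketch of that scenario follows; the module, dataset, and checkpoint path are illustrative, and `resume_from_checkpoint` is the Trainer argument of this release line.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Loading the checkpoint only restores the raw loop counters and marks the
# fit/epoch/optimizer loops as `restarting`; each loop's `reset()` then rolls
# its `current` counters back before training continues.
trainer = pl.Trainer(
    max_epochs=2,
    resume_from_checkpoint="path/to/mid_epoch.ckpt",  # illustrative path
)
trainer.fit(TinyModule(), DataLoader(RandomDataset(), batch_size=2))
```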
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/fit_loop.py
@@ -174,6 +174,8 @@ def connect(self, epoch_loop: TrainingEpochLoop):

def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:
self.epoch_progress.current.reset_on_restart()

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -197,6 +197,8 @@ def reset(self) -> None:
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
self.optim_progress.optimizer_position = 0
else:
self.optim_progress.reset_on_restart()
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

def on_run_start( # type: ignore[override]
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/progress.py
@@ -229,3 +229,7 @@ def reset_on_epoch(self) -> None:
def load_state_dict(self, state_dict: dict) -> None:
self.optimizer.load_state_dict(state_dict["optimizer"])
self.optimizer_position = state_dict["optimizer_position"]

def reset_on_restart(self) -> None:
self.optimizer.step.current.reset_on_restart()
self.optimizer.zero_grad.current.reset_on_restart()
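A worked example of what this new method does to the counters. The classes below are simplified stand-ins (the tracker's `current` level is collapsed into the tracker itself for brevity), but the delegation mirrors the method added above.

```python
from dataclasses import dataclass, field


@dataclass
class ReadyCompletedTracker:
    """Two-field stand-in: `ready` counts work that was dispatched,
    `completed` counts work that fully finished."""

    ready: int = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # anything dispatched but not finished must be redone after a restart
        self.ready = self.completed


@dataclass
class ToyOptimizerProgress:
    """Mirrors the new OptimizerProgress.reset_on_restart: it simply fans
    out to the `step` and `zero_grad` trackers."""

    step: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
    zero_grad: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)

    def reset_on_restart(self) -> None:
        self.step.reset_on_restart()
        self.zero_grad.reset_on_restart()


progress = ToyOptimizerProgress()
progress.step.ready, progress.step.completed = 5, 4  # step 5 dispatched, never finished
progress.zero_grad.ready, progress.zero_grad.completed = 5, 5
progress.reset_on_restart()
assert progress.step.ready == 4       # rolled back to the last completed step
assert progress.zero_grad.ready == 5  # already consistent, unchanged
```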
37 changes: 24 additions & 13 deletions tests/loops/test_loops.py
@@ -297,7 +297,7 @@ def val_dataloader(self):
}
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
trainer.fit_loop.load_state_dict(checkpoint)

# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches
@@ -327,7 +327,12 @@ def val_dataloader(self):
"processed": total_val_batch,
"completed": total_val_batch,
},
"current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
"current": {
"ready": stop_batch + 1,
"started": stop_batch + 1,
"processed": stop_batch,
"completed": stop_batch,
},
}
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected
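The updated expectation records the failure point verbatim: since loading no longer rewinds progress, the batch that raised is still counted as dispatched. A small illustration of the arithmetic, with a hypothetical `stop_batch` value (the test derives its own from its configuration):

```python
stop_batch = 1  # hypothetical; the test computes this from its configuration

# At the failure, the breaking batch was marked ready/started before the
# exception fired, but never reached processed/completed:
current = {
    "ready": stop_batch + 1,
    "started": stop_batch + 1,
    "processed": stop_batch,
    "completed": stop_batch,
}

# When the loop later restarts, reset() applies reset_on_restart, rolling the
# in-flight counters back so the interrupted batch is re-run:
current["ready"] = current["started"] = current["processed"] = current["completed"]
assert current["ready"] == stop_batch
```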

@@ -496,16 +501,22 @@ def configure_optimizers_multiple(self):
}
assert checkpoint["loops"]["fit_loop"] == expected

trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False)
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
state_dict = trainer.fit_loop.state_dict()

# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
# fit loop to have an iterator, which is only available during training
checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY
assert state_dict == checkpoint["loops"]["fit_loop"]

# with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
# test resetting manually; we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.reset()
trainer.fit_loop.epoch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()

epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
@@ -693,26 +704,26 @@ def test_fit_loop_reset(tmpdir):

fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])

def mid_epoch_reset_assertions():
def mid_epoch_reset_assertions(has_reset):
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.ready == (not has_reset)
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2
assert epoch_loop.batch_progress.current.ready == (2 if has_reset else 1)
assert epoch_loop.batch_progress.current.completed == (2 if has_reset else 1)

# resetting from a mid-epoch checkpoint should not change progress counters
mid_epoch_reset_assertions()
mid_epoch_reset_assertions(has_reset=False)
assert optimizer_loop.optim_progress.optimizer_position == 1
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()
mid_epoch_reset_assertions()
mid_epoch_reset_assertions(has_reset=True)
assert optimizer_loop.optim_progress.optimizer_position == 1

# reset state loaded from a checkpoint from the end of an epoch
@@ -728,14 +739,14 @@ def mid_epoch_reset_assertions():
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.ready == 1
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4
assert epoch_loop.batch_progress.current.completed == 4
assert epoch_loop.batch_progress.current.completed == 3

assert optimizer_loop.optim_progress.optimizer_position == 1
