Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename reset_on_epoch to reset_on_run #9658

Merged
merged 13 commits into from
Sep 25, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
* Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656))
* Rename `reset_on_epoch` to `reset_on_run` ([#9658](https://github.com/PyTorchLightning/pytorch-lightning/pull/9658))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/dataloader/dataloader_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def done(self) -> bool:
def reset(self) -> None:
"""Resets the internal state."""
if not self.restarting:
self.dataloader_progress.current.reset()
self.dataloader_progress.reset_on_run()
else:
self.dataloader_progress.current.reset_on_restart()
self.dataloader_progress.reset_on_restart()

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def reset(self) -> None:
self.outputs = []

if not self.restarting:
self.batch_progress.current.reset()
self.batch_progress.reset_on_run()
else:
self.batch_progress.current.reset_on_restart()
self.batch_progress.reset_on_restart()

def on_run_start(
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def reset(self) -> None:
"""Resets the loops internal state."""
self._all_batch_indices: List[int] = []
self.predictions: List[Any] = []
self.batch_progress.current.reset()
self.batch_progress.reset_on_run()

def on_run_start(
self,
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ def reset(self) -> None:
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.current.reset_on_restart()
self.scheduler_progress.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
else:
self.batch_progress.reset_on_epoch()
self.scheduler_progress.reset_on_epoch()
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
self.batch_loop.optimizer_loop.optim_progress.reset_on_run()

# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def connect(self, epoch_loop: TrainingEpochLoop):
def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:
self.epoch_progress.current.reset_on_restart()
self.epoch_progress.reset_on_restart()

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""
Expand Down
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
def reset_on_epoch(self) -> None:
self.current.reset()

def reset_on_run(self) -> None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.current.reset()

def reset_on_restart(self) -> None:
self.current.reset_on_restart()

Expand Down Expand Up @@ -188,8 +191,8 @@ class BatchProgress(Progress):

is_last_batch: bool = False

def reset_on_epoch(self) -> None:
super().reset_on_epoch()
def reset_on_run(self) -> None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
super().reset_on_run()
self.is_last_batch = False

def load_state_dict(self, state_dict: dict) -> None:
Expand Down Expand Up @@ -224,9 +227,9 @@ class OptimizerProgress(BaseProgress):
step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

def reset_on_epoch(self) -> None:
self.step.reset_on_epoch()
self.zero_grad.reset_on_epoch()
def reset_on_run(self) -> None:
self.step.reset_on_run()
self.zero_grad.reset_on_run()

def reset_on_restart(self) -> None:
self.step.reset_on_restart()
Expand Down Expand Up @@ -257,8 +260,8 @@ class OptimizationProgress(BaseProgress):
def optimizer_steps(self) -> int:
return self.optimizer.step.total.completed

def reset_on_epoch(self) -> None:
self.optimizer.reset_on_epoch()
def reset_on_run(self) -> None:
self.optimizer.reset_on_run()

def reset_on_restart(self) -> None:
self.optimizer.reset_on_restart()
Expand Down