diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3b86fdaab097d..43180c75513b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
     * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
     * Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656))
+    * Rename `reset_on_epoch` to `reset_on_run` ([#9658](https://github.com/PyTorchLightning/pytorch-lightning/pull/9658))


 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py
index 00a5ee32b933b..8e0d57c782cab 100644
--- a/pytorch_lightning/loops/dataloader/dataloader_loop.py
+++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -56,9 +56,9 @@ def done(self) -> bool:
     def reset(self) -> None:
         """Resets the internal state."""
         if not self.restarting:
-            self.dataloader_progress.current.reset()
+            self.dataloader_progress.reset_on_run()
         else:
-            self.dataloader_progress.current.reset_on_restart()
+            self.dataloader_progress.reset_on_restart()

     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         self.dataloader_progress.increment_ready()
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 3c9cf0d7172f2..3a7dd95acbf31 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -58,9 +58,9 @@ def reset(self) -> None:
         self.outputs = []

         if not self.restarting:
-            self.batch_progress.current.reset()
+            self.batch_progress.reset_on_run()
         else:
-            self.batch_progress.current.reset_on_restart()
+            self.batch_progress.reset_on_restart()

     def on_run_start(
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index bd5a45089528d..58e65233dfe81 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -46,7 +46,7 @@ def reset(self) -> None:
         """Resets the loops internal state."""
         self._all_batch_indices: List[int] = []
         self.predictions: List[Any] = []
-        self.batch_progress.current.reset()
+        self.batch_progress.reset_on_run()

     def on_run_start(
         self,
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 344514a448ba5..97108d72b58c9 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -96,12 +96,12 @@ def reset(self) -> None:
         assert self.batch_loop.optimizer_loop is not None
         if self.restarting:
             self.batch_progress.reset_on_restart()
-            self.scheduler_progress.current.reset_on_restart()
+            self.scheduler_progress.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
         else:
-            self.batch_progress.reset_on_epoch()
-            self.scheduler_progress.reset_on_epoch()
-            self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
+            self.batch_progress.reset_on_run()
+            self.scheduler_progress.reset_on_run()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_run()

         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 9a4f7c510f303..c6da69feb65b2 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -175,7 +175,7 @@ def connect(self, epoch_loop: TrainingEpochLoop):
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         if self.restarting:
-            self.epoch_progress.current.reset_on_restart()
+            self.epoch_progress.reset_on_restart()

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py
index fa74963ee2964..7eaf219910b67 100644
--- a/pytorch_lightning/trainer/progress.py
+++ b/pytorch_lightning/trainer/progress.py
@@ -151,6 +151,9 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
     def reset_on_epoch(self) -> None:
         self.current.reset()

+    def reset_on_run(self) -> None:
+        self.current.reset()
+
     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()

@@ -188,8 +191,8 @@ class BatchProgress(Progress):

     is_last_batch: bool = False

-    def reset_on_epoch(self) -> None:
-        super().reset_on_epoch()
+    def reset_on_run(self) -> None:
+        super().reset_on_run()
         self.is_last_batch = False

     def load_state_dict(self, state_dict: dict) -> None:
@@ -224,9 +227,9 @@ class OptimizerProgress(BaseProgress):
     step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
     zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

-    def reset_on_epoch(self) -> None:
-        self.step.reset_on_epoch()
-        self.zero_grad.reset_on_epoch()
+    def reset_on_run(self) -> None:
+        self.step.reset_on_run()
+        self.zero_grad.reset_on_run()

     def reset_on_restart(self) -> None:
         self.step.reset_on_restart()
@@ -257,8 +260,8 @@ class OptimizationProgress(BaseProgress):
     def optimizer_steps(self) -> int:
         return self.optimizer.step.total.completed

-    def reset_on_epoch(self) -> None:
-        self.optimizer.reset_on_epoch()
+    def reset_on_run(self) -> None:
+        self.optimizer.reset_on_run()

     def reset_on_restart(self) -> None:
         self.optimizer.reset_on_restart()
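
For context on what the rename buys: after this diff, loops call `reset_on_run()` / `reset_on_restart()` on the progress dataclass itself instead of reaching into `progress.current` directly, which also lets subclasses such as `BatchProgress` and `OptimizerProgress` reset their extra state (`is_last_batch`, the nested `step`/`zero_grad` trackers) in one call. The snippet below is a minimal, self-contained sketch of that pattern, not the actual `pytorch_lightning.trainer.progress` classes (the real trackers carry more counters than shown); the `reset_on_restart` behavior of rolling `ready` back to `completed` is an assumption based on the "Use `completed` over `processed`" changelog entry.

```python
from dataclasses import dataclass, field


@dataclass
class ReadyCompletedTracker:
    """Toy tracker: counts events that became ready / completed."""

    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        # fresh run: forget everything
        self.ready = 0
        self.completed = 0

    def reset_on_restart(self) -> None:
        # resuming from a checkpoint: roll back anything that was
        # "ready" but never "completed" so it is not double-counted
        self.ready = self.completed


@dataclass
class Progress:
    """Pairs a lifetime `total` tracker with a per-run `current` tracker."""

    total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
    current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1

    def reset_on_run(self) -> None:
        # renamed from `reset_on_epoch`: clears only the per-run counters
        self.current.reset()

    def reset_on_restart(self) -> None:
        self.current.reset_on_restart()


if __name__ == "__main__":
    progress = Progress()
    for _ in range(3):
        progress.increment_ready()
        progress.increment_completed()

    progress.reset_on_run()
    assert progress.current.completed == 0  # per-run counters cleared
    assert progress.total.completed == 3  # lifetime counters untouched
```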