From f239948e04b4a1a6cb8e900bc73731c0f4a1b6a7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 02:19:00 +0200 Subject: [PATCH 01/11] Add `is_last_batch` to progress tracking --- .../loops/epoch/training_epoch_loop.py | 19 ++++++------- pytorch_lightning/trainer/progress.py | 28 +++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 2 +- tests/trainer/test_trainer.py | 3 +- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 3d7f36477c55e..a559b45e0b297 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import Progress, SchedulerProgress +from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -43,9 +43,7 @@ def __init__(self, min_steps: int, max_steps: int): self.max_steps: int = max_steps self.global_step: int = 0 - # manually tracking which is the last batch is necessary for iterable dataset support - self.is_last_batch: Optional[bool] = None - self.batch_progress = Progress() + self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() self.batch_loop: Optional[TrainingBatchLoop] = None @@ -94,12 +92,10 @@ def reset(self) -> None: assert self.batch_loop is not None assert self.batch_loop.optimizer_loop is not None if self.restarting: - self.batch_progress.current.reset_on_restart() + self.batch_progress.reset_on_restart() self.scheduler_progress.current.reset_on_restart() self.batch_loop.optimizer_loop.optim_progress.reset_on_restart() - self.is_last_batch = False - # track epoch output self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] @@ -127,6 +123,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: StopIteration: When the epoch is canceled by the user returning -1 """ batch_idx, (batch, is_last) = next(self.dataloader_iter) + self.batch_progress.is_last_batch = is_last if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): @@ -139,8 +136,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_processed() - self.is_last_batch = is_last - # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: raise StopIteration @@ -178,7 +173,7 @@ def on_advance_end(self): # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self._should_check_val_fx(self.batch_idx, self.is_last_batch) + should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch) if should_check_val: self.trainer.validating = True self._run_validation() @@ -259,7 +254,9 @@ def _accumulated_batches_reached(self) -> bool: def _num_training_batches_reached(self) -> bool: """Checks if we are in the last batch or if there are more batches to follow.""" - 
return self.batch_progress.current.ready == self.trainer.num_training_batches or self.is_last_batch + return ( + self.batch_progress.current.ready == self.trainer.num_training_batches or self.batch_progress.is_last_batch + ) def _should_accumulate(self) -> bool: """Checks if the optimizer step should be performed or gradients should be accumulated for the current diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 0f07c61999e1c..5d643667f3b62 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -159,7 +159,9 @@ def reset_on_restart(self) -> None: @dataclass class DataLoaderProgress(Progress): - """Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally + """Tracks dataloader progress. + + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: @@ -171,9 +173,31 @@ class DataLoaderProgress(Progress): current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) +@dataclass +class BatchProgress(Progress): + """Tracks batch progress. + + These counters are local to a trainer rank. By default, they are not globally + synced across all ranks. + + Args: + total: Tracks the total dataloader progress. + current: Tracks the current dataloader progress. + is_last_batch: Whether the batch is the last one. This is useful for iterable datasets. + """ + + is_last_batch: bool = False + + def reset_on_restart(self) -> None: + super().reset_on_restart() + self.is_last_batch = False + + @dataclass class SchedulerProgress(Progress): - """Tracks the scheduler progress. These counters are local to a trainer rank. By default, they are not globally + """Tracks scheduler progress. + + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. 
Args: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2e115decf3ade..94caa69c000f8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1877,7 +1877,7 @@ def min_steps(self) -> Optional[int]: @property def is_last_batch(self) -> bool: - return self.fit_loop.epoch_loop.is_last_batch + return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property def fit_loop(self) -> FitLoop: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5f1bdd1f34541..6c30047445889 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -244,9 +244,8 @@ def on_train_batch_start(self, *_): def on_train_batch_end(self, outputs, batch, batch_idx, *_): end_state_dict = self.state_dict() - is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches - if is_last_batch or self.opt_step_called: + if self.trainer.is_last_batch or self.opt_step_called: assert self.check(self.start_state_dict, end_state_dict, equal=False) else: assert self.check(self.start_state_dict, end_state_dict) From 9eedea0f26422bb9d206360fc8671706ccf9e39e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 02:24:46 +0200 Subject: [PATCH 02/11] Docstring fixes --- pytorch_lightning/trainer/progress.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 5d643667f3b62..37db5e55caae9 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -161,8 +161,7 @@ def reset_on_restart(self) -> None: class DataLoaderProgress(Progress): """Tracks dataloader progress. - These counters are local to a trainer rank. By default, they are not globally - synced across all ranks. + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: total: Tracks the total dataloader progress. @@ -177,8 +176,7 @@ class DataLoaderProgress(Progress): class BatchProgress(Progress): """Tracks batch progress. - These counters are local to a trainer rank. By default, they are not globally - synced across all ranks. + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: total: Tracks the total dataloader progress. @@ -197,8 +195,7 @@ def reset_on_restart(self) -> None: class SchedulerProgress(Progress): """Tracks scheduler progress. - These counters are local to a trainer rank. By default, they are not globally - synced across all ranks. + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: total: Tracks the total scheduler progress. From 181b6824d2338610cc0a8e8f47cc571380740016 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 02:26:01 +0200 Subject: [PATCH 03/11] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 954a83d2f034f..9452afd06eccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)) * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371)) * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561)) + * Integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657)) - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628)) From 36d424ac9648dc8f2a6de42facffffb32c9196da Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 02:27:01 +0200 Subject: [PATCH 04/11] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9452afd06eccc..a8ecf5306db7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,10 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Progress tracking * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)) + * Add `BatchProgress` and integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657)) * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)) * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371)) * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561)) - * Integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657)) - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628)) From 2353a565dd7bb890b66cd85ffebeb916ec512803 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 02:53:04 +0200 Subject: [PATCH 05/11] Fix tests --- tests/loops/test_loop_state_dict.py | 1 + tests/loops/test_loops.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index ad5d0159036d5..0459e0033e46a 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -51,6 +51,7 @@ def test_loops_state_dict_structure(): "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": {"ready": 0, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 47145a2f8f408..2f7cc233d6831 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -20,7 +20,9 @@ import pytest import torch +from torch.utils.data import DataLoader +from pl_examples.bug_report_model import RandomDataset from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loops import Loop, TrainingBatchLoop @@ -443,6 +445,7 @@ def 
configure_optimizers_multiple(self): "processed": stop_batch, "completed": stop_batch, }, + "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, @@ -548,13 +551,15 @@ def configure_optimizers_multiple(self): return optimizers, lr_schedulers + def train_dataloader(self): + return DataLoader(RandomDataset(32, n_batches)) + model = TestModel() model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=n_epochs, - limit_train_batches=n_batches, limit_val_batches=0, accumulate_grad_batches=accumulate_grad_batches, progress_bar_refresh_rate=0, @@ -563,6 +568,8 @@ def configure_optimizers_multiple(self): ) trainer.fit(model) + assert trainer.num_training_batches == n_batches + ckpt_path = trainer.checkpoint_callback.best_model_path assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) @@ -607,6 +614,7 @@ def configure_optimizers_multiple(self): "processed": n_batches, "completed": n_batches, }, + "is_last_batch": True, }, "epoch_loop.scheduler_progress": { "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total}, From 25e764b3133d0be77fe1b01b917f6432f5f789c0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 03:10:31 +0200 Subject: [PATCH 06/11] Use `reset_on_epoch` --- .../loops/epoch/training_epoch_loop.py | 4 ++-- pytorch_lightning/trainer/progress.py | 22 ++++++++++++------- tests/loops/test_loops.py | 1 + 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index a559b45e0b297..f829c20e557b1 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -100,8 +100,8 @@ def reset(self) -> None: self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] if not self.restarting or self._num_training_batches_reached(): - self.batch_progress.current.reset() - self.scheduler_progress.current.reset() + self.batch_progress.reset_on_epoch() + self.scheduler_progress.reset_on_epoch() self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch() def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 37db5e55caae9..064e5a8e07c04 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -153,6 +153,9 @@ def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) self.current.load_state_dict(state_dict["current"]) + def reset_on_epoch(self) -> None: + self.current.reset() + def reset_on_restart(self) -> None: self.current.reset_on_restart() @@ -186,8 +189,8 @@ class BatchProgress(Progress): is_last_batch: bool = False - def reset_on_restart(self) -> None: - super().reset_on_restart() + def reset_on_epoch(self) -> None: + super().reset_on_epoch() self.is_last_batch = False @@ -219,8 +222,12 @@ class OptimizerProgress(BaseProgress): zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker)) def reset_on_epoch(self) -> None: - self.step.current.reset() - self.zero_grad.current.reset() + self.step.reset_on_epoch() + self.zero_grad.reset_on_epoch() + + def reset_on_restart(self): + self.step.reset_on_restart() + self.zero_grad.reset_on_restart() def load_state_dict(self, state_dict: dict) 
-> None: self.step.load_state_dict(state_dict["step"]) @@ -250,10 +257,9 @@ def optimizer_steps(self) -> int: def reset_on_epoch(self) -> None: self.optimizer.reset_on_epoch() + def reset_on_restart(self) -> None: + self.optimizer.reset_on_restart() + def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.optimizer_position = state_dict["optimizer_position"] - - def reset_on_restart(self) -> None: - self.optimizer.step.current.reset_on_restart() - self.optimizer.zero_grad.current.reset_on_restart() diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 2f7cc233d6831..5d58f1dbfa941 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -552,6 +552,7 @@ def configure_optimizers_multiple(self): return optimizers, lr_schedulers def train_dataloader(self): + # override to test the `is_last_batch` value return DataLoader(RandomDataset(32, n_batches)) model = TestModel() From 187f97d4975c65293003cbdb76234903e047698c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 03:24:19 +0200 Subject: [PATCH 07/11] Rename `reset_on_epoch` to `reset_on_run` --- .../loops/dataloader/dataloader_loop.py | 4 ++-- .../loops/epoch/evaluation_epoch_loop.py | 4 ++-- .../loops/epoch/prediction_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 8 ++++---- pytorch_lightning/loops/fit_loop.py | 2 +- pytorch_lightning/trainer/progress.py | 16 ++++++++-------- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index 00a5ee32b933b..8e0d57c782cab 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -56,9 +56,9 @@ def done(self) -> bool: def reset(self) -> None: """Resets the internal state.""" if not self.restarting: - self.dataloader_progress.current.reset() + self.dataloader_progress.reset_on_run() else: - self.dataloader_progress.current.reset_on_restart() + self.dataloader_progress.reset_on_restart() def on_advance_start(self, *args: Any, **kwargs: Any) -> None: self.dataloader_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 3c9cf0d7172f2..3a7dd95acbf31 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -58,9 +58,9 @@ def reset(self) -> None: self.outputs = [] if not self.restarting: - self.batch_progress.current.reset() + self.batch_progress.reset_on_run() else: - self.batch_progress.current.reset_on_restart() + self.batch_progress.reset_on_restart() def on_run_start( self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index bd5a45089528d..58e65233dfe81 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -46,7 +46,7 @@ def reset(self) -> None: """Resets the loops internal state.""" self._all_batch_indices: List[int] = [] self.predictions: List[Any] = [] - self.batch_progress.current.reset() + self.batch_progress.reset_on_run() def on_run_start( self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py 
b/pytorch_lightning/loops/epoch/training_epoch_loop.py index f829c20e557b1..727b166dacb90 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -93,16 +93,16 @@ def reset(self) -> None: assert self.batch_loop.optimizer_loop is not None if self.restarting: self.batch_progress.reset_on_restart() - self.scheduler_progress.current.reset_on_restart() + self.scheduler_progress.reset_on_restart() self.batch_loop.optimizer_loop.optim_progress.reset_on_restart() # track epoch output self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] if not self.restarting or self._num_training_batches_reached(): - self.batch_progress.reset_on_epoch() - self.scheduler_progress.reset_on_epoch() - self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch() + self.batch_progress.reset_on_run() + self.scheduler_progress.reset_on_run() + self.batch_loop.optimizer_loop.optim_progress.reset_on_run() def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # hook diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 9a4f7c510f303..c6da69feb65b2 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -175,7 +175,7 @@ def connect(self, epoch_loop: TrainingEpochLoop): def reset(self) -> None: """Resets the internal state of this loop.""" if self.restarting: - self.epoch_progress.current.reset_on_restart() + self.epoch_progress.reset_on_restart() def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 064e5a8e07c04..d3f976343c17b 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -153,7 +153,7 @@ def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) self.current.load_state_dict(state_dict["current"]) - def reset_on_epoch(self) -> None: + def reset_on_run(self) -> None: self.current.reset() def reset_on_restart(self) -> None: @@ -189,8 +189,8 @@ class BatchProgress(Progress): is_last_batch: bool = False - def reset_on_epoch(self) -> None: - super().reset_on_epoch() + def reset_on_run(self) -> None: + super().reset_on_run() self.is_last_batch = False @@ -221,9 +221,9 @@ class OptimizerProgress(BaseProgress): step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker)) zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker)) - def reset_on_epoch(self) -> None: - self.step.reset_on_epoch() - self.zero_grad.reset_on_epoch() + def reset_on_run(self) -> None: + self.step.reset_on_run() + self.zero_grad.reset_on_run() def reset_on_restart(self): self.step.reset_on_restart() @@ -254,8 +254,8 @@ class OptimizationProgress(BaseProgress): def optimizer_steps(self) -> int: return self.optimizer.step.total.completed - def reset_on_epoch(self) -> None: - self.optimizer.reset_on_epoch() + def reset_on_run(self) -> None: + self.optimizer.reset_on_run() def reset_on_restart(self) -> None: self.optimizer.reset_on_restart() From 9cc44ada235044cdba04f095ebb199dce487d9a2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 03:27:47 +0200 Subject: [PATCH 08/11] Fix mypy --- pytorch_lightning/trainer/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/progress.py 
b/pytorch_lightning/trainer/progress.py index d3f976343c17b..34636ee4383f9 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -225,7 +225,7 @@ def reset_on_run(self) -> None: self.step.reset_on_run() self.zero_grad.reset_on_run() - def reset_on_restart(self): + def reset_on_restart(self) -> None: self.step.reset_on_restart() self.zero_grad.reset_on_restart() From 556635d423e179446b350d5aa790a94bff34f226 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 15:16:37 +0200 Subject: [PATCH 09/11] Bad merge --- pytorch_lightning/trainer/progress.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 17a5ca610b603..946eed0c16f78 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -262,9 +262,6 @@ def reset_on_run(self) -> None: def reset_on_restart(self) -> None: self.optimizer.reset_on_restart() - def reset_on_restart(self) -> None: - self.optimizer.reset_on_restart() - def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.optimizer_position = state_dict["optimizer_position"] From d352998f1c4654dd59815340b8de48299f64f375 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 15:17:08 +0200 Subject: [PATCH 10/11] Bad merge --- tests/trainer/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 56384f939f8ef..7d565edb00608 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -244,8 +244,9 @@ def on_train_batch_start(self, *_): def on_train_batch_end(self, outputs, batch, batch_idx, *_): end_state_dict = self.state_dict() + is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches - if self.trainer.is_last_batch or self.opt_step_called: + if is_last_batch or self.opt_step_called: assert self.check(self.start_state_dict, end_state_dict, equal=False) else: assert self.check(self.start_state_dict, end_state_dict) From 056e84b4d3b0c1c47eb407537159094fbec0e977 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 15:18:25 +0200 Subject: [PATCH 11/11] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c1328193ab2..4e0f353bef10d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371)) * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561)) * Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656)) + * Rename `reset_on_epoch` to `reset_on_run` ([#9658](https://github.com/PyTorchLightning/pytorch-lightning/pull/9658)) - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
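
For reference, the net effect of this series can be summarised in a small standalone sketch. The `ReadyCompletedTracker`, `Progress` and `BatchProgress` classes below are simplified stand-ins for the real dataclasses in `pytorch_lightning/trainer/progress.py` (only the counters touched by these patches are modelled), but the `reset_on_run` semantics and the extra `is_last_batch` entry in the serialized state mirror the diffs above.

```python
# Minimal sketch of the `is_last_batch` tracking added in this series.
# The classes here are simplified stand-ins for the dataclasses in
# `pytorch_lightning/trainer/progress.py`; only the behaviour relevant to
# `is_last_batch` is modelled.
from dataclasses import asdict, dataclass, field


@dataclass
class ReadyCompletedTracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class Progress:
    total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
    current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def reset_on_run(self) -> None:
        # renamed from `reset_on_epoch` in patch 07: only `current` is cleared
        self.current.reset()


@dataclass
class BatchProgress(Progress):
    # new field from patch 01: whether the last seen batch closed the epoch,
    # which matters for iterable datasets with no known length
    is_last_batch: bool = False

    def reset_on_run(self) -> None:
        super().reset_on_run()
        self.is_last_batch = False


progress = BatchProgress()
progress.increment_ready()
# the epoch loop sets this from the `is_last` flag yielded by the data fetcher
progress.is_last_batch = True

# the flag is serialized alongside the counters, as the updated
# `test_loops_state_dict_structure` expects
assert asdict(progress)["is_last_batch"] is True

# starting a new run clears `current` and the flag but keeps `total`
progress.reset_on_run()
assert progress.total.ready == 1
assert progress.current.ready == 0
assert progress.is_last_batch is False
```

With this in place, `Trainer.is_last_batch` simply forwards to `fit_loop.epoch_loop.batch_progress.is_last_batch`, so the manually tracked `TrainingEpochLoop.is_last_batch` attribute removed in patch 01 is no longer needed.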