From 4740be26b7ee7d4f6b572da1223405e138d60d3b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 04:10:58 +0200 Subject: [PATCH 01/11] Fix `self.log(on_epoch=True)` on_batch_start --- .../loops/batch/training_batch_loop.py | 53 ++++++++----------- .../logging_/test_train_loop_logging.py | 12 +++++ 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index faf6966ca4c2f..3bbaddc3cc627 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -44,6 +44,7 @@ def __init__(self) -> None: self._outputs: _OUTPUTS_TYPE = [] self._warning_cache: WarningCache = WarningCache() self._remaining_splits: Optional[List[Any]] = None + self._exit_signal: int = 0 @property def done(self) -> bool: @@ -58,35 +59,6 @@ def connect( if manual_loop is not None: self.manual_loop = manual_loop - def run(self, batch: Any, batch_idx: int) -> AttributeDict: - """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks. - - Args: - batch: the current batch to run the train step on - batch_idx: the index of the current batch - """ - if batch is None: - self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") - return AttributeDict(signal=0, outputs=[]) - - # hook - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - super().run(batch, batch_idx) - - output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None # free memory - return output - def reset(self) -> None: """Resets the loop state.""" self._outputs = [] @@ -108,13 +80,31 @@ def advance(self, batch, batch_idx): batch: the current batch to run the training on (this is not the split!) batch_idx: the index of the current batch """ - void(batch) + if batch is None: + self._warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") + raise StopIteration + split_idx, split_batch = self._remaining_splits.pop(0) self.split_idx = split_idx # let logger connector extract current batch size self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) + # hook + self.trainer.logger_connector.on_batch_start() + response = self.trainer.call_hook("on_batch_start") + if response == -1: + self._exit_signal = -1 + raise StopIteration + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) + if response == -1: + self._exit_signal = -1 + raise StopIteration + + self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() + # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx) @@ -131,6 +121,9 @@ def on_run_end(self) -> None: self.optimizer_loop._hiddens = None # this is not necessary as the manual loop runs for only 1 iteration, but just in case self.manual_loop._hiddens = None + output, self._outputs = AttributeDict(signal=self._exit_signal, outputs=self._outputs), None # free memory + self._exit_signal = 0 + return output def teardown(self) -> None: # release memory diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 3c72a78331720..5b548e1301dee 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module): pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) + def on_batch_start(self, _, pl_module, *__): + self.make_logging( + pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + def on_batch_end(self, _, pl_module): self.make_logging( pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) + def on_train_batch_start(self, _, pl_module, *__): + self.make_logging( + pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + def on_train_batch_end(self, _, pl_module, *__): self.make_logging( pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices @@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx): "on_train_start": 1, "on_epoch_start": 1, "on_train_epoch_start": 1, + "on_train_batch_start": 2, "on_train_batch_end": 2, + "on_batch_start": 2, "on_batch_end": 2, "on_train_epoch_end": 1, "on_epoch_end": 1, From b22696b5ce9a688c06b1a5c2e8dc91f87b3cd7da Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 04:22:12 +0200 Subject: [PATCH 02/11] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 774064191714b..96392e605d437 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -436,6 +436,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691)) +- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780)) + + ## [1.4.9] - 2021-09-30 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704)) From 7bf64ab64a3df87f73c3edcaa812370915e547d2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 05:13:08 +0200 Subject: [PATCH 03/11] Remove signal --- .../loops/batch/training_batch_loop.py | 33 +++---------------- .../loops/epoch/evaluation_epoch_loop.py | 4 +-- .../loops/epoch/training_epoch_loop.py | 27 +++++++++++---- .../logger_connector/logger_connector.py | 10 +++--- tests/loops/test_evaluation_loop_flow.py | 8 ++--- tests/loops/test_training_loop_flow_scalar.py | 11 ++----- 6 files changed, 36 insertions(+), 57 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 3bbaddc3cc627..f0a6366ac4bae 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -23,8 +23,6 @@ from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop from pytorch_lightning.loops.utilities import _get_active_optimizers from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] @@ -42,9 +40,7 @@ def __init__(self) -> None: self.manual_loop = ManualOptimization() self._outputs: _OUTPUTS_TYPE = [] - self._warning_cache: WarningCache = WarningCache() self._remaining_splits: Optional[List[Any]] = None - self._exit_signal: int = 0 @property def done(self) -> bool: @@ -80,30 +76,10 @@ def advance(self, batch, batch_idx): batch: the current batch to run the training on (this is not the split!) batch_idx: the index of the current batch """ - if batch is None: - self._warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - raise StopIteration - - split_idx, split_batch = self._remaining_splits.pop(0) - self.split_idx = split_idx + self.split_idx, split_batch = self._remaining_splits.pop(0) # let logger connector extract current batch size - self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) - - # hook - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - self._exit_signal = -1 - raise StopIteration - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) - if response == -1: - self._exit_signal = -1 - raise StopIteration - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() + self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch) # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: @@ -117,12 +93,11 @@ def advance(self, batch, batch_idx): # then `advance` doesn't finish and an empty dict is returned self._outputs.append(outputs) - def on_run_end(self) -> None: + def on_run_end(self) -> _OUTPUTS_TYPE: self.optimizer_loop._hiddens = None # this is not necessary as the manual loop runs for only 1 iteration, but just in case self.manual_loop._hiddens = None - output, self._outputs = AttributeDict(signal=self._exit_signal, outputs=self._outputs), None # free memory - self._exit_signal = 0 + output, self._outputs = self._outputs, None # free memory return output def teardown(self) -> None: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 2c7f829bb86e9..1ca3bbb9c670b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -186,10 +186,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Raises: AssertionError: If the number of dataloaders is None (has not yet been set). 
""" - self.trainer.logger_connector.on_batch_start() + self.trainer.logger_connector.on_batch_start(batch_idx) assert self._num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders) if self.trainer.testing: self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index d1c1000e6fc07..86b95068fdb53 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] @@ -55,6 +56,7 @@ def __init__(self, min_steps: int, max_steps: int): self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] + self._warning_cache = WarningCache() @property def total_batch_idx(self) -> int: @@ -134,21 +136,34 @@ def advance(self, *args: Any, **kwargs: Any) -> None: batch_idx, (batch, is_last) = next(self.dataloader_iter) self.batch_progress.is_last_batch = is_last + if batch is None: + self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") + raise StopIteration + if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() + # hook + self.trainer.logger_connector.on_batch_start(batch_idx) + response = self.trainer.call_hook("on_batch_start") + if response == -1: + raise StopIteration + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) + if response == -1: + raise StopIteration + + self.batch_progress.increment_started() + with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) self.batch_progress.increment_processed() - # when returning -1 from train_step, we end epoch early - if batch_output.signal == -1: - raise StopIteration - # update non-plateau LR schedulers # update epoch-interval ones only when we are at the end of training epoch self.update_lr_schedulers("step", update_plateau_schedulers=False) @@ -156,7 +171,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = self._prepare_outputs_training_batch_end( - batch_output.outputs, + batch_output, automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization, num_optimizers=len(self.trainer.optimizers), ) @@ -167,7 +182,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_completed() if is_overridden("training_epoch_end", self.trainer.lightning_module): - self._outputs.append(batch_output.outputs) + self._outputs.append(batch_output) # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py 
b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 85323e92dc7e5..a2404a2d69015 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -132,7 +132,7 @@ def _increment_eval_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None @@ -140,7 +140,6 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: # track batch_size assert self.trainer._results is not None self.trainer._results.extract_batch_size(batch) - self._batch_idx = batch_idx def update_eval_step_metrics(self) -> None: if self.trainer.sanity_checking: @@ -207,14 +206,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]: Train metric updates """ - def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: + def on_train_split_start(self, split_idx: int, split_batch: Any) -> None: assert self.trainer._results is not None # when the user requests `dataloader_iter`, we can't track the batch_size # and this is left to user responsibility. if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): self.trainer._results.extract_batch_size(split_batch) - - self._batch_idx = batch_idx self._split_idx = split_idx def update_train_step_metrics(self) -> None: @@ -255,7 +252,8 @@ def _log_gpus_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self) -> None: + def on_batch_start(self, batch_idx: int) -> None: + self._batch_idx = batch_idx self._epoch_end_reached = False def epoch_end_reached(self) -> None: diff --git a/tests/loops/test_evaluation_loop_flow.py b/tests/loops/test_evaluation_loop_flow.py index d927262021d82..bffb18061f47e 100644 --- a/tests/loops/test_evaluation_loop_flow.py +++ b/tests/loops/test_evaluation_loop_flow.py @@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py index 4f64a906646ba..ea4c404947a71 
100644 --- a/tests/loops/test_training_loop_flow_scalar.py +++ b/tests/loops/test_training_loop_flow_scalar.py @@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx): for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) if not batch_idx % 2: - assert out.outputs == [] - assert out.signal == 0 + assert out == [] def test_training_step_none_batches(tmpdir): From 4a3253b175ef425e14175694602dc5241a62ed5e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 05:46:34 +0200 Subject: [PATCH 04/11] Fix test --- .../loops/batch/training_batch_loop.py | 1 + .../loops/epoch/training_epoch_loop.py | 40 ++++++++++--------- tests/loops/test_training_loop_flow_scalar.py | 20 ++++------ 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index f0a6366ac4bae..c1d800c42d853 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -76,6 +76,7 @@ def advance(self, batch, batch_idx): batch: the current batch to run the training on (this is not the split!) batch_idx: the index of the current batch """ + void(batch) self.split_idx, split_batch = self._remaining_splits.pop(0) # let logger connector extract current batch size diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 86b95068fdb53..6e42c319f212f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -136,31 +136,33 @@ def advance(self, *args: Any, **kwargs: Any) -> None: batch_idx, (batch, is_last) = next(self.dataloader_iter) self.batch_progress.is_last_batch = is_last - if batch is None: - self._warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - raise StopIteration - if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() - # hook - self.trainer.logger_connector.on_batch_start(batch_idx) - response = self.trainer.call_hook("on_batch_start") - if response == -1: - raise StopIteration - - # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) - if response == -1: - raise StopIteration - - self.batch_progress.increment_started() - - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, batch_idx) + if batch is None: + self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") + batch_output = [] + else: + # hook + self.trainer.logger_connector.on_batch_start(batch_idx) + response = self.trainer.call_hook("on_batch_start") + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration + + self.batch_progress.increment_started() + + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.batch_loop.run(batch, batch_idx) self.batch_progress.increment_processed() diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py index ea4c404947a71..7c44846f12d29 100644 --- a/tests/loops/test_training_loop_flow_scalar.py +++ b/tests/loops/test_training_loop_flow_scalar.py @@ -316,7 +316,6 @@ def test_training_step_none_batches(tmpdir): class TestModel(BoringModel): def __init__(self): super().__init__() - self.counter = 0 def collate_none_when_even(self, batch): @@ -328,27 +327,24 @@ def collate_none_when_even(self, batch): return result def train_dataloader(self): - return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even) + return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even) + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + if batch_idx % 2 == 0: + assert outputs == [] + else: + assert outputs model = TestModel() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=4, limit_val_batches=1, max_epochs=4, weights_summary=None, logger=False, checkpoint_callback=False, + progress_bar_refresh_rate=0, ) with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"): trainer.fit(model) - - trainer.state.stage = RunningStage.TRAINING - - # manually check a few batches - for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - if not batch_idx % 2: - assert out.outputs == [] - assert out.signal == 0 From 9d2798abf8892e00b3538b4b89d48ebcab56599f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 17:14:20 +0200 Subject: [PATCH 05/11] Fix `self.log(on_epoch=True, reduce_fx=sum)` on_batch_start --- .../loops/epoch/evaluation_epoch_loop.py | 4 ++-- .../loops/epoch/training_epoch_loop.py | 7 +++++- .../logger_connector/logger_connector.py | 22 +++++++++---------- .../connectors/logger_connector/result.py | 11 ++++++---- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py 
b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 1ca3bbb9c670b..dcd1cb2d1cc17 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -186,10 +186,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Raises: AssertionError: If the number of dataloaders is None (has not yet been set). """ - self.trainer.logger_connector.on_batch_start(batch_idx) + self.trainer.logger_connector.on_batch_start(batch_idx, batch) assert self._num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders) if self.trainer.testing: self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 6e42c319f212f..f4a89e08688e8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -142,12 +142,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() + # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be + # different if tbptt is enabled + batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch) + if batch is None: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") batch_output = [] else: # hook - self.trainer.logger_connector.on_batch_start(batch_idx) response = self.trainer.call_hook("on_batch_start") if response == -1: self.batch_progress.increment_processed() @@ -164,6 +167,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) + self.trainer._results.batch_size = batch_size + self.batch_progress.increment_processed() # update non-plateau LR schedulers diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a2404a2d69015..7c0b475bc63bc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -132,15 +132,11 @@ def _increment_eval_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None: + def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None - # track batch_size - assert self.trainer._results is not None - self.trainer._results.extract_batch_size(batch) - def update_eval_step_metrics(self) -> None: if self.trainer.sanity_checking: return @@ -207,12 +203,8 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]: """ def on_train_split_start(self, split_idx: int, split_batch: Any) -> None: - assert self.trainer._results is not None - # when the user requests `dataloader_iter`, we can't track the batch_size - # and this is left to user responsibility. 
- if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): - self.trainer._results.extract_batch_size(split_batch) self._split_idx = split_idx + self.on_new_batch(split_batch) def update_train_step_metrics(self) -> None: if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization: @@ -249,12 +241,20 @@ def _log_gpus_metrics(self) -> None: Utilities and properties """ + def on_new_batch(self, batch: Any) -> int: + # when the user requests `dataloader_iter`, we can't track the batch_size + # and this is left to user responsibility. + if not isinstance(batch, pl.utilities.fetching.DataLoaderIterDataFetcher): + assert self.trainer._results is not None + return self.trainer._results.extract_batch_size(batch) + def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self, batch_idx: int) -> None: + def on_batch_start(self, batch_idx: int, batch: Any) -> int: self._batch_idx = batch_idx self._epoch_end_reached = False + return self.on_new_batch(batch) def epoch_end_reached(self) -> None: self._epoch_end_reached = True diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 2d3ba728d6933..d66f151e64ac7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -204,7 +204,7 @@ def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None: elif self.meta.is_max_reduction or self.meta.is_min_reduction: self.value = self.meta.reduce_fx(self.value, value.mean()) elif self.meta.is_sum_reduction: - self.value += value.mean() * batch_size + self.value += value.mean() else: self.value = value self._forward_cache = value._forward_cache @@ -550,11 +550,14 @@ def fn(item: ResultMetric) -> None: apply_to_collection(self, ResultMetric, fn) - def extract_batch_size(self, batch: Any) -> None: + def extract_batch_size(self, batch: Any) -> int: + batch_size = 1 try: - self.batch_size = extract_batch_size(batch) + batch_size = extract_batch_size(batch) except RecursionError: - self.batch_size = 1 + pass + self.batch_size = batch_size # the setter converts it to `Tensor` + return batch_size def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": """Move all data to the given device.""" From c377ce52210cded24059563dd9c9cd3f3c8b490b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 1 Oct 2021 17:17:50 +0200 Subject: [PATCH 06/11] Minor change --- .../trainer/connectors/logger_connector/result.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index d66f151e64ac7..f47cf8712db46 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -551,11 +551,10 @@ def fn(item: ResultMetric) -> None: apply_to_collection(self, ResultMetric, fn) def extract_batch_size(self, batch: Any) -> int: - batch_size = 1 try: batch_size = extract_batch_size(batch) except RecursionError: - pass + batch_size = 1 self.batch_size = batch_size # the setter converts it to `Tensor` return batch_size From a9743938f623bf4f626105e2109255170d766e5e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 15:40:55 +0200 Subject: [PATCH 07/11] Bad merge --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 
++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 88dec4b607b14..fbe0ff2e9e2cd 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union import numpy as np import torch @@ -59,6 +59,9 @@ def __init__(self, min_steps: int, max_steps: int): self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] self._warning_cache = WarningCache() + self._dataloader_iter: Optional[Iterator] = None + # caches the loaded dataloader state until dataloader objects are available + self._dataloader_state_dict: Dict[str, Any] = {} @property def total_batch_idx(self) -> int: From b3f10cd3a4aaf83c7e87b1202bec9a471d8c4d0c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 15:41:55 +0200 Subject: [PATCH 08/11] Bad merge --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b95acdf3adf1e..17771756bd22b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -539,7 +539,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780)) -ยบ- Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413)) +- Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413)) - Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788)) From 7e46bcb02f0c241efcf3e6399d6cec613ede3afe Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 16:07:06 +0200 Subject: [PATCH 09/11] Add test --- .../logging_/test_train_loop_logging.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index f7f7190adb9bd..b4271099d3edb 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -20,13 +20,14 @@ import numpy as np import pytest import torch +from torch.utils.data import DataLoader from torchmetrics import Accuracy from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel, RandomDictDataset +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset from tests.helpers.runif import RunIf @@ -721,3 +722,40 @@ def on_before_backward(self, loss: torch.Tensor) -> None: gpus=1, ) trainer.fit(TestModel()) + + +def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir): + class TestModel(BoringModel): + def on_train_epoch_end(self): + assert all(v == 
3 for v in self.trainer.callback_metrics.values()) + + def on_validation_epoch_end(self): + assert all(v == 3 for v in self.trainer.callback_metrics.values()) + + def on_train_batch_start(self, batch, batch_idx): + assert self.trainer._results.batch_size == 2 + self.log("on_train_batch_start", 1.0, reduce_fx="sum") + + def on_train_batch_end(self, outputs, batch, batch_idx): + assert self.trainer._results.batch_size == 2 + self.log("on_train_batch_end", 1.0, reduce_fx="sum") + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + assert self.trainer._results.batch_size == 2 + self.log("on_validation_batch_start", 1.0, reduce_fx="sum") + + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + assert self.trainer._results.batch_size == 2 + self.log("on_validation_batch_end", 1.0, reduce_fx="sum") + + model = TestModel() + trainer = Trainer( + enable_progress_bar=False, + limit_train_batches=3, + limit_val_batches=3, + num_sanity_val_steps=3, + max_epochs=1, + ) + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) From 22ed37a8c7834c2c75ac2d5fa3effd1f39be7569 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 18:08:16 +0200 Subject: [PATCH 10/11] Fix test --- .../trainer/connectors/logger_connector/logger_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index fbad223ea580b..e165a51ef21f2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -256,9 +256,10 @@ def _log_gpus_metrics(self) -> None: def on_new_batch(self, batch: Any) -> int: # when the user requests `dataloader_iter`, we can't track the batch_size # and this is left to user responsibility. - if not isinstance(batch, pl.utilities.fetching.DataLoaderIterDataFetcher): + if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter): assert self.trainer._results is not None return self.trainer._results.extract_batch_size(batch) + return 1 def on_epoch_start(self) -> None: self._epoch_end_reached = False From cfede9cb71e8ee919fbf46d712724593c3dc03fc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 19 Oct 2021 15:04:02 +0200 Subject: [PATCH 11/11] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17771756bd22b..0e96a7e954bd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -536,6 +536,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691)) +- Fixed `self.log(on_epoch=True, reduce_fx=sum))` for the `on_batch_start` and `on_train_batch_start` hooks ([#9791(https://github.com/PyTorchLightning/pytorch-lightning/pull/9791)) + + - Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))
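
Not part of the patches above — a minimal, self-contained usage sketch of the behaviour they target: calling `self.log(..., on_epoch=True, reduce_fx="sum")` from the `on_train_batch_start` hook. It assumes a pytorch_lightning checkout that already includes these commits; the module, metric names, dataset, and trainer arguments below are illustrative only and do not come from the patches themselves.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    # tiny stand-in dataset so the sketch runs on its own
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class DemoModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        # logging from this hook with on_epoch=True is what these patches fix;
        # with reduce_fx="sum" the epoch value accumulates the raw logged values
        self.log("batches_seen", 1.0, on_step=False, on_epoch=True, reduce_fx="sum")

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=3, limit_val_batches=0, logger=False)
    trainer.fit(DemoModule())
    # with the fix in place, "batches_seen" should aggregate to the number of batches run (3 here)
    print(trainer.callback_metrics)

This mirrors what the new `test_on_epoch_logging_with_sum_and_on_batch_start` test asserts, but as a standalone script rather than a test case.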
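
For the `result.py` change in patch 05 specifically, here is a hypothetical before/after illustration of the sum reduction; the values and batch size are made up, and only the arithmetic reflects the diff:

import torch

logged = [torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)]  # one logged value per batch
batch_size = 2

# before the patch: `self.value += value.mean() * batch_size`
old_epoch_value = sum(v.mean() * batch_size for v in logged)  # tensor(6.)
# after the patch:  `self.value += value.mean()`
new_epoch_value = sum(v.mean() for v in logged)               # tensor(3.)

print(old_epoch_value.item(), new_epoch_value.item())

This matches the `assert all(v == 3 ...)` expectations in the new test, where three batches of size 2 each log a value of 1.0.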