diff --git a/CHANGELOG.md b/CHANGELOG.md
index f809e66c6b7ad..43ebc6464ac51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -527,11 +527,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
 
 
+- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))
+
+
 - Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413))
 
 
 - Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))
 
+
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
 
 
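Editor's note: a minimal, self-contained sketch of the user-facing behaviour the CHANGELOG entry above refers to. The module name, metric name, and dataset here are illustrative only (not taken from the PR); with this fix, an `on_epoch=True` log from `on_train_batch_start` (or `on_batch_start`) is reduced and written once per epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class DemoModule(LightningModule):  # hypothetical module, only for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def on_train_batch_start(self, batch, batch_idx):
        # with the fix, this value is aggregated (mean by default) and logged once per epoch
        self.log("seen_batch_start", 1.0, on_step=False, on_epoch=True)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=2)


if __name__ == "__main__":
    Trainer(max_epochs=1, limit_val_batches=0).fit(DemoModule())
```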
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 93e156070d3d1..c1d800c42d853 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,9 +23,6 @@
 from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
 from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import AttributeDict
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.warnings import WarningCache
 
 _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
 
@@ -43,7 +40,6 @@ def __init__(self) -> None:
         self.manual_loop = ManualOptimization()
 
         self._outputs: _OUTPUTS_TYPE = []
-        self._warning_cache: WarningCache = WarningCache()
         self._remaining_splits: Optional[List[Any]] = None
 
     @property
@@ -59,42 +55,6 @@ def connect(
         if manual_loop is not None:
             self.manual_loop = manual_loop
 
-    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
-        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
-
-        Args:
-            batch: the current batch to run the train step on
-            batch_idx: the index of the current batch
-        """
-        if batch is None:
-            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
-            return AttributeDict(signal=0, outputs=[])
-
-        # hook
-        self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # hook
-        # TODO: Update this in v1.7 (deprecation: #9816)
-        model_fx = self.trainer.lightning_module.on_train_batch_start
-        extra_kwargs = (
-            {"dataloader_idx": 0}
-            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
-            else {}
-        )
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
-
-        super().run(batch, batch_idx)
-
-        output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None  # free memory
-        return output
-
     def reset(self) -> None:
         """Resets the loop state."""
         self._outputs = []
@@ -117,11 +77,10 @@ def advance(self, batch, batch_idx):
             batch_idx: the index of the current batch
         """
         void(batch)
-        split_idx, split_batch = self._remaining_splits.pop(0)
-        self.split_idx = split_idx
+        self.split_idx, split_batch = self._remaining_splits.pop(0)
 
         # let logger connector extract current batch size
-        self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
+        self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch)
 
         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
@@ -135,10 +94,12 @@ def advance(self, batch, batch_idx):
             # then `advance` doesn't finish and an empty dict is returned
             self._outputs.append(outputs)
 
-    def on_run_end(self) -> None:
+    def on_run_end(self) -> _OUTPUTS_TYPE:
         self.optimizer_loop._hiddens = None
         # this is not necessary as the manual loop runs for only 1 iteration, but just in case
         self.manual_loop._hiddens = None
+        output, self._outputs = self._outputs, None  # free memory
+        return output
 
     def teardown(self) -> None:
         # release memory
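Editor's note: a hedged sketch (not part of the patch) of how the slimmed-down `TrainingBatchLoop.run` is consumed after this change; `trainer`, `batch`, and `batch_idx` are assumed to be prepared the same way the updated tests further down prepare them.

```python
def run_one_training_batch(trainer, batch, batch_idx):
    # `run` now returns the per-split outputs list directly; the old
    # `AttributeDict(signal=..., outputs=...)` wrapper is gone, and the early-exit
    # signal is handled by the training epoch loop instead.
    outputs = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
    assert isinstance(outputs, list)
    return outputs
```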
""" - self.trainer.logger_connector.on_batch_start() + self.trainer.logger_connector.on_batch_start(batch_idx) assert self._num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders) if self.trainer.testing: self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index fe3a2dc7431cc..4cc8eaa811231 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -28,6 +28,7 @@ from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] @@ -57,6 +58,7 @@ def __init__(self, min_steps: int, max_steps: int): self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] + self._warning_cache = WarningCache() self._dataloader_iter: Optional[Iterator] = None # caches the loaded dataloader state until dataloader objects are available self._dataloader_state_dict: Dict[str, Any] = {} @@ -151,14 +153,37 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, batch_idx) + if batch is None: + self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") + batch_output = [] + else: + # hook + self.trainer.logger_connector.on_batch_start(batch_idx) + response = self.trainer.call_hook("on_batch_start") + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration + + # TODO: Update this in v1.7 (deprecation: #9816) + model_fx = self.trainer.lightning_module.on_train_batch_start + extra_kwargs = ( + {"dataloader_idx": 0} + if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True) + else {} + ) - self.batch_progress.increment_processed() + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration - # when returning -1 from train_step, we end epoch early - if batch_output.signal == -1: - raise StopIteration + self.batch_progress.increment_started() + + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.batch_loop.run(batch, batch_idx) + + self.batch_progress.increment_processed() # update non-plateau LR schedulers # update epoch-interval ones only when we are at the end of training epoch @@ -167,7 +192,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = self._prepare_outputs_training_batch_end( - batch_output.outputs, + batch_output, automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization, num_optimizers=len(self.trainer.optimizers), ) @@ -186,7 +211,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_completed() if is_overridden("training_epoch_end", 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 21684d6831a65..cb01e7edbc97a 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -138,7 +138,7 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1
 
-    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
+    def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
         model = self.trainer.lightning_module
         # set dataloader_idx only if multiple ones
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
@@ -146,7 +146,6 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         # track batch_size
         assert self.trainer._results is not None
         self.trainer._results.extract_batch_size(batch)
-        self._batch_idx = batch_idx
 
     def update_eval_step_metrics(self) -> None:
         if self.trainer.sanity_checking:
@@ -213,14 +212,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
     Train metric updates
     """
 
-    def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
+    def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
         assert self.trainer._results is not None
         # when the user requests `dataloader_iter`, we can't track the batch_size
         # and this is left to user responsibility.
         if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
             self.trainer._results.extract_batch_size(split_batch)
-
-        self._batch_idx = batch_idx
         self._split_idx = split_idx
 
     def update_train_step_metrics(self) -> None:
@@ -267,7 +264,8 @@ def _log_gpus_metrics(self) -> None:
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False
 
-    def on_batch_start(self) -> None:
+    def on_batch_start(self, batch_idx: int) -> None:
+        self._batch_idx = batch_idx
         self._epoch_end_reached = False
 
     def epoch_end_reached(self) -> None:
diff --git a/tests/loops/test_evaluation_loop_flow.py b/tests/loops/test_evaluation_loop_flow.py
index 5a9d0a737350c..0fe90557b3530 100644
--- a/tests/loops/test_evaluation_loop_flow.py
+++ b/tests/loops/test_evaluation_loop_flow.py
@@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
 
-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
 
-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
    assert isinstance(train_step_out["loss"], torch.Tensor)
diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py
index 0501cbdf529db..f7f539efef8cd 100644
--- a/tests/loops/test_training_loop_flow_scalar.py
+++ b/tests/loops/test_training_loop_flow_scalar.py
@@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
 
-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
 
-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
-            assert out.outputs == []
-            assert out.signal == 0
+            assert out == []
 
 
 def test_training_step_none_batches(tmpdir):
@@ -321,7 +316,6 @@ def test_training_step_none_batches(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
-
             self.counter = 0
 
         def collate_none_when_even(self, batch):
@@ -333,12 +327,17 @@ def collate_none_when_even(self, batch):
             return result
 
         def train_dataloader(self):
-            return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even)
+            return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even)
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            if batch_idx % 2 == 0:
+                assert outputs == []
+            else:
+                assert outputs
 
     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        limit_train_batches=4,
         limit_val_batches=1,
         max_epochs=4,
         enable_model_summary=False,
@@ -348,12 +347,3 @@ def train_dataloader(self):
 
     with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"):
         trainer.fit(model)
-
-    trainer.state.stage = RunningStage.TRAINING
-
-    # manually check a few batches
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-        if not batch_idx % 2:
-            assert out.outputs == []
-            assert out.signal == 0
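Editor's note: for context on `test_training_step_none_batches` above, a self-contained sketch of the pattern it exercises; the dataset class below is a minimal stand-in, not the actual test helper. A `collate_fn` that returns `None` makes the training loop warn once and skip that batch, which is why `on_train_batch_end` sees empty outputs on the even batch indices.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class RandomDataset(Dataset):  # minimal stand-in for the test helper of the same name
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


counter = 0


def collate_none_when_even(batch):
    # return None for every other batch so the loop skips it with the warning asserted above
    global counter
    result = None if counter % 2 == 0 else default_collate(batch)
    counter += 1
    return result


loader = DataLoader(RandomDataset(32, 4), collate_fn=collate_none_when_even)
```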
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 5e9db17b2de62..f7f7190adb9bd 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module):
                 pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices
             )
 
+        def on_batch_start(self, _, pl_module, *__):
+            self.make_logging(
+                pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+            )
+
         def on_batch_end(self, _, pl_module):
             self.make_logging(
                 pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
             )
 
+        def on_train_batch_start(self, _, pl_module, *__):
+            self.make_logging(
+                pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+            )
+
         def on_train_batch_end(self, _, pl_module, *__):
             self.make_logging(
                 pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
@@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx):
         "on_train_start": 1,
         "on_epoch_start": 1,
         "on_train_epoch_start": 1,
+        "on_train_batch_start": 2,
         "on_train_batch_end": 2,
+        "on_batch_start": 2,
         "on_batch_end": 2,
         "on_train_epoch_end": 1,
         "on_epoch_end": 1,
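Editor's note: a hedged sketch of what the two new callback hooks in this test verify; the callback name and metric names are illustrative only. With this change, `pl_module.log` can be called from `on_batch_start` and `on_train_batch_start` in a `Callback`, and each hook logs once per training batch, which is presumably why the expected counts above are 2.

```python
from pytorch_lightning import Callback


class LogAtBatchStart(Callback):  # hypothetical callback, for illustration only
    def on_batch_start(self, trainer, pl_module):
        pl_module.log("cb_on_batch_start", 1.0, on_step=True, on_epoch=True)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pl_module.log("cb_on_train_batch_start", 1.0, on_step=True, on_epoch=True)
```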