diff --git a/CHANGELOG.md b/CHANGELOG.md index a8afc6ed53ef5..bfcf2b252f869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -116,7 +116,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `pytorch_lightning.core.lightning.LightningModule` in favor of `pytorch_lightning.core.module.LightningModule` ([#12740](https://github.com/PyTorchLightning/pytorch-lightning/pull/12740)) -- +- Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184)) ### Removed @@ -235,6 +235,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084)) +- Fixed logging on step level for eval mode ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184)) + + - diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 7f4f4bd12365c..85f8ce8a9c47b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -234,10 +234,15 @@ def _get_max_batches(self) -> List[int]: def _reload_evaluation_dataloaders(self) -> None: """Reloads dataloaders if necessary.""" + dataloaders = None if self.trainer.testing: self.trainer.reset_test_dataloader() + dataloaders = self.trainer.test_dataloaders elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl: self.trainer.reset_val_dataloader() + dataloaders = self.trainer.val_dataloaders + if dataloaders is not None: + self.epoch_loop._reset_dl_batch_idx(len(dataloaders)) def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_start`` hooks.""" diff --git 
a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index f6e49fa310a24..9317546e0c0ee 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -49,6 +49,7 @@ def __init__(self) -> None: self._dl_max_batches = 0 self._data_fetcher: Optional[AbstractDataFetcher] = None self._dataloader_state_dict: Dict[str, Any] = {} + self._dl_batch_idx = [0] @property def done(self) -> bool: @@ -150,7 +151,10 @@ def advance( # type: ignore[override] self.batch_progress.increment_completed() # log batch metrics - self.trainer._logger_connector.update_eval_step_metrics() + if not self.trainer.sanity_checking: + dataloader_idx = kwargs.get("dataloader_idx", 0) + self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx]) + self._dl_batch_idx[dataloader_idx] += 1 # track epoch level outputs if self._should_track_batch_outputs_for_epoch_end() and output is not None: @@ -301,3 +305,6 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: if self.trainer.testing: return is_overridden("test_epoch_end", model) return is_overridden("validation_epoch_end", model) + + def _reset_dl_batch_idx(self, num_dataloaders: int) -> None: + self._dl_batch_idx = [0] * num_dataloaders diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 40334387c0688..ab14a7aec23cf 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -204,8 +204,9 @@ def on_run_start(self) -> None: # type: ignore[override] if not self._iteration_based_training(): self.epoch_progress.current.completed = self.epoch_progress.current.processed - # reset train dataloader and val dataloader - self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) + self.trainer.reset_train_dataloader(self.trainer.lightning_module) + # reload the evaluation dataloaders too for proper display 
in the progress bar + self.epoch_loop.val_loop._reload_evaluation_dataloaders() data_fetcher_cls = _select_data_fetcher(self.trainer) self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index db34f0a77e71d..b6e8d425dbb02 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -19,7 +19,6 @@ from pytorch_lightning.loggers import Logger, TensorBoardLogger from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.model_helpers import is_overridden @@ -29,8 +28,6 @@ class LoggerConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.trainer = trainer - self._val_log_step: int = 0 - self._test_log_step: int = 0 self._progress_bar_metrics: _PBAR_DICT = {} self._logged_metrics: _OUT_DICT = {} self._callback_metrics: _OUT_DICT = {} @@ -116,35 +113,15 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: Evaluation metric updates """ - @property - def _eval_log_step(self) -> Optional[int]: - if self.trainer.state.stage is RunningStage.VALIDATING: - return self._val_log_step - if self.trainer.state.stage is RunningStage.TESTING: - return self._test_log_step - return None - - def _increment_eval_log_step(self) -> None: - if self.trainer.state.stage is RunningStage.VALIDATING: - self._val_log_step += 1 - elif self.trainer.state.stage is RunningStage.TESTING: - 
self._test_log_step += 1 - def _evaluation_epoch_end(self) -> None: results = self.trainer._results assert results is not None results.dataloader_idx = None - def update_eval_step_metrics(self) -> None: + def update_eval_step_metrics(self, step: int) -> None: assert not self._epoch_end_reached - if self.trainer.sanity_checking: - return - # logs user requested information to logger - self.log_metrics(self.metrics["log"], step=self._eval_log_step) - - # increment the step even if nothing was logged - self._increment_eval_log_step() + self.log_metrics(self.metrics["log"], step=step) def update_eval_epoch_metrics(self) -> _OUT_DICT: assert self._epoch_end_reached diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3b298911ed209..a7592ffaac1af 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1950,7 +1950,15 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No Args: model: The ``LightningModule`` if called outside of the trainer scope. + + .. deprecated:: v1.7 + This method is deprecated in v1.7 and will be removed in v1.9. + Please use ``Trainer.reset_{train,val}_dataloader`` instead. """ + rank_zero_deprecation( + "`Trainer.reset_train_val_dataloaders` has been deprecated in v1.7 and will be removed in v1.9." 
+ " Use `Trainer.reset_{train,val}_dataloader` instead" + ) if self.train_dataloader is None: self.reset_train_dataloader(model=model) if self.val_dataloaders is None: diff --git a/tests/deprecated_api/test_remove_1-9.py b/tests/deprecated_api/test_remove_1-9.py index 5d7b0b6260e3b..0ad4d7a7410fd 100644 --- a/tests/deprecated_api/test_remove_1-9.py +++ b/tests/deprecated_api/test_remove_1-9.py @@ -17,6 +17,7 @@ import pytest import pytorch_lightning.loggers.base as logger_base +from pytorch_lightning import Trainer from pytorch_lightning.core.module import LightningModule from pytorch_lightning.utilities.cli import LightningCLI from pytorch_lightning.utilities.rank_zero import rank_zero_only @@ -106,6 +107,12 @@ def test_old_callback_path(): from pytorch_lightning.callbacks.base import Callback with pytest.deprecated_call( - match="pytorch_lightning.callbacks.base.Callback has been deprecated in v1.7" " and will be removed in v1.9." + match="pytorch_lightning.callbacks.base.Callback has been deprecated in v1.7 and will be removed in v1.9." 
): Callback() + + +def test_deprecated_dataloader_reset(): + trainer = Trainer() + with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"): + trainer.reset_train_val_dataloaders() diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 176912d67950c..9f94b38b5c0f1 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -973,3 +973,62 @@ def test_rich_print_results(inputs, expected): EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert capture.get() == expected.lstrip() + + +@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") +@pytest.mark.parametrize("num_dataloaders", (1, 2)) +def test_eval_step_logging(mock_log_metrics, tmpdir, num_dataloaders): + """Test that eval step during fit/validate/test is updated correctly.""" + + class CustomBoringModel(BoringModel): + def validation_step(self, batch, batch_idx, dataloader_idx=None): + self.log(f"val_log_{self.trainer.state.fn}", batch_idx, on_step=True, on_epoch=False) + + def test_step(self, batch, batch_idx, dataloader_idx=None): + self.log("test_log", batch_idx, on_step=True, on_epoch=False) + + def val_dataloader(self): + return [super().val_dataloader()] * num_dataloaders + + def test_dataloader(self): + return [super().test_dataloader()] * num_dataloaders + + validation_epoch_end = None + test_epoch_end = None + + limit_batches = 4 + max_epochs = 3 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=max_epochs, + limit_train_batches=1, + limit_val_batches=limit_batches, + limit_test_batches=limit_batches, + ) + model = CustomBoringModel() + + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + def get_suffix(dl_idx): + return f"/dataloader_idx_{dl_idx}" if num_dataloaders == 2 else "" + + eval_steps = range(limit_batches) + fit_calls = [ 
+ call(metrics={f"val_log_fit{get_suffix(dl_idx)}": float(step)}, step=step + (limit_batches * epoch)) + for epoch in range(max_epochs) + for dl_idx in range(num_dataloaders) + for step in eval_steps + ] + validate_calls = [ + call(metrics={f"val_log_validate{get_suffix(dl_idx)}": float(val)}, step=val) + for dl_idx in range(num_dataloaders) + for val in eval_steps + ] + test_calls = [ + call(metrics={f"test_log{get_suffix(dl_idx)}": float(val)}, step=val) + for dl_idx in range(num_dataloaders) + for val in eval_steps + ] + assert mock_log_metrics.mock_calls == fit_calls + validate_calls + test_calls diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 115d78e24f9ae..27667f7b2e043 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1243,7 +1243,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): assert tracker.mock_calls == [ call.reset_val_dataloader(), - call.reset_train_dataloader(model=model), + call.reset_train_dataloader(model), call.reset_test_dataloader(), ]