5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -116,7 +116,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.core.lightning.LightningModule` in favor of `pytorch_lightning.core.module.LightningModule` ([#12740](https://github.com/PyTorchLightning/pytorch-lightning/pull/12740))


-
- Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))

### Removed

@@ -235,6 +235,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084))


- Fixed logging on step level for eval mode ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))


-


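For migration context, a minimal before/after sketch of the deprecated API (assuming a plain `Trainer` with no dataloaders attached, as in the deprecation test further down; on such a trainer the resets return early after any warning):

```python
from pytorch_lightning import Trainer

trainer = Trainer()

# Deprecated in v1.7, slated for removal in v1.9 -- emits a deprecation
# warning and internally falls back to the per-split resets:
trainer.reset_train_val_dataloaders()

# Replacement: reset each split explicitly.
trainer.reset_train_dataloader()
trainer.reset_val_dataloader()
```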
5 changes: 5 additions & 0 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -234,10 +234,15 @@ def _get_max_batches(self) -> List[int]:

def _reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary."""
dataloaders = None
if self.trainer.testing:
self.trainer.reset_test_dataloader()
dataloaders = self.trainer.test_dataloaders
elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
self.trainer.reset_val_dataloader()
dataloaders = self.trainer.val_dataloaders
if dataloaders is not None:
self.epoch_loop._reset_dl_batch_idx(len(dataloaders))

def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
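The new `_reset_dl_batch_idx` call seeds one logging step counter per evaluation dataloader whenever the dataloaders are reloaded. A standalone sketch of that bookkeeping (hypothetical `EvalStepCounter` name, not the Lightning internals):

```python
from typing import List


class EvalStepCounter:
    """Per-dataloader step index for eval-time logging (illustrative only)."""

    def __init__(self) -> None:
        self._dl_batch_idx: List[int] = [0]

    def reset(self, num_dataloaders: int) -> None:
        # Mirrors _reset_dl_batch_idx: one independent counter per dataloader.
        self._dl_batch_idx = [0] * num_dataloaders

    def next_step(self, dataloader_idx: int = 0) -> int:
        # Return the current step for this dataloader, then advance it,
        # mirroring the update in EvaluationEpochLoop.advance.
        step = self._dl_batch_idx[dataloader_idx]
        self._dl_batch_idx[dataloader_idx] += 1
        return step


counter = EvalStepCounter()
counter.reset(num_dataloaders=2)
assert [counter.next_step(0), counter.next_step(0)] == [0, 1]
assert counter.next_step(1) == 0  # counters advance independently
```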
9 changes: 8 additions & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -49,6 +49,7 @@ def __init__(self) -> None:
self._dl_max_batches = 0
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._dataloader_state_dict: Dict[str, Any] = {}
self._dl_batch_idx = [0]

@property
def done(self) -> bool:
@@ -150,7 +151,10 @@ def advance( # type: ignore[override]
self.batch_progress.increment_completed()

# log batch metrics
self.trainer._logger_connector.update_eval_step_metrics()
if not self.trainer.sanity_checking:
dataloader_idx = kwargs.get("dataloader_idx", 0)
self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx])
self._dl_batch_idx[dataloader_idx] += 1

# track epoch level outputs
if self._should_track_batch_outputs_for_epoch_end() and output is not None:
@@ -301,3 +305,6 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool:
if self.trainer.testing:
return is_overridden("test_epoch_end", model)
return is_overridden("validation_epoch_end", model)

def _reset_dl_batch_idx(self, num_dataloaders: int) -> None:
self._dl_batch_idx = [0] * num_dataloaders
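From the user's side, this changes the x-axis of metrics logged with `on_step=True` during evaluation: the step is now the per-dataloader batch index, reset when the dataloaders reload, and nothing is logged during sanity checking. A minimal sketch of such a model (training plumbing omitted, sketch only):

```python
from pytorch_lightning import LightningModule


class EvalLoggingModel(LightningModule):  # sketch only; not a full model
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # Logged against self._dl_batch_idx[dataloader_idx] in the epoch loop,
        # which resets whenever the evaluation dataloaders are reloaded.
        self.log("val_metric", float(batch_idx), on_step=True, on_epoch=False)
```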
5 changes: 3 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -204,8 +204,9 @@ def on_run_start(self) -> None: # type: ignore[override]
if not self._iteration_based_training():
self.epoch_progress.current.completed = self.epoch_progress.current.processed

# reset train dataloader and val dataloader
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
self.trainer.reset_train_dataloader(self.trainer.lightning_module)
# reload the evaluation dataloaders too for proper display in the progress bar
self.epoch_loop.val_loop._reload_evaluation_dataloaders()

data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -19,7 +19,6 @@
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -29,8 +28,6 @@
class LoggerConnector:
def __init__(self, trainer: "pl.Trainer") -> None:
self.trainer = trainer
self._val_log_step: int = 0
self._test_log_step: int = 0
self._progress_bar_metrics: _PBAR_DICT = {}
self._logged_metrics: _OUT_DICT = {}
self._callback_metrics: _OUT_DICT = {}
@@ -116,35 +113,15 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
Evaluation metric updates
"""

@property
def _eval_log_step(self) -> Optional[int]:
if self.trainer.state.stage is RunningStage.VALIDATING:
return self._val_log_step
if self.trainer.state.stage is RunningStage.TESTING:
return self._test_log_step
return None

def _increment_eval_log_step(self) -> None:
if self.trainer.state.stage is RunningStage.VALIDATING:
self._val_log_step += 1
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def _evaluation_epoch_end(self) -> None:
results = self.trainer._results
assert results is not None
results.dataloader_idx = None

def update_eval_step_metrics(self) -> None:
def update_eval_step_metrics(self, step: int) -> None:
assert not self._epoch_end_reached
if self.trainer.sanity_checking:
return

# logs user requested information to logger
self.log_metrics(self.metrics["log"], step=self._eval_log_step)

# increment the step even if nothing was logged
self._increment_eval_log_step()
self.log_metrics(self.metrics["log"], step=step)

def update_eval_epoch_metrics(self) -> _OUT_DICT:
assert self._epoch_end_reached
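The design shift in the connector is from hidden mutable state to an explicit argument: the loop, which already knows the dataloader and batch position, now supplies the step, so `update_eval_step_metrics` no longer advances a counter as a side effect. A simplified contrast of the two designs (illustrative only, not the actual class):

```python
class Connector:
    """Simplified contrast of the two designs (illustrative only)."""

    def __init__(self) -> None:
        self._eval_log_step = 0  # before: connector-owned hidden counter

    def update_before(self, metrics: dict) -> None:
        print(f"log {metrics} at step {self._eval_log_step}")
        self._eval_log_step += 1  # side effect, even if nothing was logged

    def update_after(self, metrics: dict, step: int) -> None:
        print(f"log {metrics} at step {step}")  # caller owns the step
```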
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1950,7 +1950,15 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:

Args:
model: The ``LightningModule`` if called outside of the trainer scope.

.. deprecated:: v1.7
This method is deprecated in v1.7 and will be removed in v1.9.
Please use ``Trainer.reset_{train,val}_dataloader`` instead.
"""
rank_zero_deprecation(
"`Trainer.reset_train_val_dataloaders` has been deprecated in v1.7 and will be removed in v1.9."
" Use `Trainer.reset_{train,val}_dataloader` instead"
)
if self.train_dataloader is None:
self.reset_train_dataloader(model=model)
if self.val_dataloaders is None:
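The deprecation itself follows Lightning's usual pattern: warn once at rank zero, then delegate to the replacement. A standalone sketch of that pattern with hypothetical `old_api`/`new_api` names:

```python
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


def new_api() -> None:
    ...  # the replacement does the actual work


def old_api() -> None:
    """Deprecated; kept only for backward compatibility."""
    rank_zero_deprecation(
        "`old_api` has been deprecated in v1.7 and will be removed in v1.9. Use `new_api` instead."
    )
    new_api()


old_api()  # emits a LightningDeprecationWarning, then does the real work
```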
9 changes: 8 additions & 1 deletion tests/deprecated_api/test_remove_1-9.py
@@ -17,6 +17,7 @@
import pytest

import pytorch_lightning.loggers.base as logger_base
from pytorch_lightning import Trainer
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.rank_zero import rank_zero_only
@@ -106,6 +107,12 @@ def test_old_callback_path():
from pytorch_lightning.callbacks.base import Callback

with pytest.deprecated_call(
match="pytorch_lightning.callbacks.base.Callback has been deprecated in v1.7" " and will be removed in v1.9."
match="pytorch_lightning.callbacks.base.Callback has been deprecated in v1.7 and will be removed in v1.9."
):
Callback()


def test_deprecated_dataloader_reset():
trainer = Trainer()
with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"):
trainer.reset_train_val_dataloaders()
59 changes: 59 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -973,3 +973,62 @@ def test_rich_print_results(inputs, expected):
EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert capture.get() == expected.lstrip()


@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
@pytest.mark.parametrize("num_dataloaders", (1, 2))
def test_eval_step_logging(mock_log_metrics, tmpdir, num_dataloaders):
"""Test that eval step during fit/validate/test is updated correctly."""

class CustomBoringModel(BoringModel):
def validation_step(self, batch, batch_idx, dataloader_idx=None):
self.log(f"val_log_{self.trainer.state.fn}", batch_idx, on_step=True, on_epoch=False)

def test_step(self, batch, batch_idx, dataloader_idx=None):
self.log("test_log", batch_idx, on_step=True, on_epoch=False)

def val_dataloader(self):
return [super().val_dataloader()] * num_dataloaders

def test_dataloader(self):
return [super().test_dataloader()] * num_dataloaders

validation_epoch_end = None
test_epoch_end = None

limit_batches = 4
max_epochs = 3
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=max_epochs,
limit_train_batches=1,
limit_val_batches=limit_batches,
limit_test_batches=limit_batches,
)
model = CustomBoringModel()

trainer.fit(model)
trainer.validate(model)
trainer.test(model)

def get_suffix(dl_idx):
return f"/dataloader_idx_{dl_idx}" if num_dataloaders == 2 else ""

eval_steps = range(limit_batches)
fit_calls = [
call(metrics={f"val_log_fit{get_suffix(dl_idx)}": float(step)}, step=step + (limit_batches * epoch))
for epoch in range(max_epochs)
for dl_idx in range(num_dataloaders)
for step in eval_steps
]
validate_calls = [
call(metrics={f"val_log_validate{get_suffix(dl_idx)}": float(val)}, step=val)
for dl_idx in range(num_dataloaders)
for val in eval_steps
]
test_calls = [
call(metrics={f"test_log{get_suffix(dl_idx)}": float(val)}, step=val)
for dl_idx in range(num_dataloaders)
for val in eval_steps
]
assert mock_log_metrics.mock_calls == fit_calls + validate_calls + test_calls
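As a sanity check on the expected call list: with `limit_batches = 4` and `max_epochs = 3`, fit-time validation should log at steps `0..3`, `4..7`, `8..11` per dataloader (the counters persist across epochs within `fit` since the dataloaders are not reloaded), while the standalone `validate`/`test` runs restart at `0..3`:

```python
limit_batches, max_epochs = 4, 3
fit_steps = [step + limit_batches * epoch for epoch in range(max_epochs) for step in range(limit_batches)]
assert fit_steps == list(range(12))  # 0..11, as encoded in fit_calls above
```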
2 changes: 1 addition & 1 deletion tests/trainer/test_dataloaders.py
@@ -1243,7 +1243,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):

assert tracker.mock_calls == [
call.reset_val_dataloader(),
call.reset_train_dataloader(model=model),
call.reset_train_dataloader(model),
call.reset_test_dataloader(),
]
