diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3afc2a79b0484..fb312fa3cd2aa 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -244,6 +244,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed setting the tracking uri in `MLFlowLogger` for logging artifacts to the MLFlow server ([#18395](https://github.com/Lightning-AI/lightning/pull/18395)) +- Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415)) + + ## [2.0.5] - 2023-07-07 ### Fixed diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 424cfe7cf6d14..63aa129ea9d42 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -393,6 +393,10 @@ def _check_dataloader_iterable( source: _DataLoaderSource, trainer_fn: TrainerFn, ) -> None: + if isinstance(dataloader, DataLoader): + # Fast path: `torch.utils.data.DataLoader` is always iterable, calling iter() would be expensive + return + try: iter(dataloader) # type: ignore[call-overload] except TypeError: diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index df263db1b00f8..e8aa8de9f19ff 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -818,8 +818,6 @@ def _get_iterator(self): expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown elif should_fail: expected = [ - # iterable check - 0, # epoch ends 1, # teardown @@ -827,8 +825,6 @@ def _get_iterator(self): ] else: expected = [ - # iterable check - 0, # epoch ends 1, 2, @@ -843,8 +839,6 @@ def _get_iterator(self): expected = [ # sanity check 0, - # iterable check - 0, # epoch ends 0, 1, @@ -853,8 +847,6 @@ def 
_get_iterator(self): expected = [ # sanity check 0, - # iterable check - 0, # epoch ends 0, 1, diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 1e70bd0e59bb7..7c9dc9126dc0e 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -26,7 +26,12 @@ from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache +from lightning.pytorch.trainer.connectors.data_connector import ( + _check_dataloader_iterable, + _DataHookSelector, + _DataLoaderSource, + warning_cache, +) from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import _update_dataloader @@ -643,3 +648,17 @@ def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage, setattr(model, dl_method, lambda: dataloader) with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"): trainer_fn(model) + + +def test_iterable_check_on_known_iterators(): + """Test that we only call `iter()` on the dataloader object if it isn't a known type.""" + iterable = Mock() + iterable.__iter__ = Mock(return_value=iter(range(3))) + _check_dataloader_iterable(iterable, Mock(), Mock()) + iterable.__iter__.assert_called_once() + + # If it's a dataloader, we don't call the expensive `__iter__` method + dataloader = Mock(spec=DataLoader) + dataloader.__iter__ = Mock() + _check_dataloader_iterable(dataloader, Mock(), Mock()) + dataloader.__iter__.assert_not_called()