Skip to content

Commit

Permalink
Avoid expensive iter() call to dataloader in dataloader checks (#18415
Browse files Browse the repository at this point in the history
)
  • Loading branch information
awaelchli authored Aug 28, 2023
1 parent 722fdea commit d77132b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed setting the tracking uri in `MLFlowLogger` for logging artifacts to the MLFlow server ([#18395](https://github.com/Lightning-AI/lightning/pull/18395))


- Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415))


## [2.0.5] - 2023-07-07

### Fixed
Expand Down
4 changes: 4 additions & 0 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ def _check_dataloader_iterable(
source: _DataLoaderSource,
trainer_fn: TrainerFn,
) -> None:
if isinstance(dataloader, DataLoader):
# Fast path: `torch.utils.data.DataLoader` is always iterable, calling iter() would be expensive
return

try:
iter(dataloader) # type: ignore[call-overload]
except TypeError:
Expand Down
8 changes: 0 additions & 8 deletions tests/tests_pytorch/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,17 +818,13 @@ def _get_iterator(self):
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
elif should_fail:
expected = [
# iterable check
0,
# epoch ends
1,
# teardown
1,
]
else:
expected = [
# iterable check
0,
# epoch ends
1,
2,
Expand All @@ -843,8 +839,6 @@ def _get_iterator(self):
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
0,
1,
Expand All @@ -853,8 +847,6 @@ def _get_iterator(self):
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
0,
1,
Expand Down
21 changes: 20 additions & 1 deletion tests/tests_pytorch/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataHookSelector,
_DataLoaderSource,
warning_cache,
)
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import _update_dataloader
Expand Down Expand Up @@ -643,3 +648,17 @@ def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage,
setattr(model, dl_method, lambda: dataloader)
with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"):
trainer_fn(model)


def test_iterable_check_on_known_iterators():
"""Test that we only call the `iter()` on the dataloader object if it isn't a known type."""
iterable = Mock()
iterable.__iter__ = Mock(return_value=iter(range(3)))
_check_dataloader_iterable(iterable, Mock(), Mock())
iterable.__iter__.assert_called_once()

# If it's a datalaoder, we don't call the expensive `__iter__` method
dataloader = Mock(spec=DataLoader)
dataloader.__iter__ = Mock()
_check_dataloader_iterable(dataloader, Mock(), Mock())
dataloader.__iter__.assert_not_called()

0 comments on commit d77132b

Please sign in to comment.