Error checking for non-iterables (#17007)
carmocca authored Mar 10, 2023
1 parent b09c077 commit da276db
Showing 7 changed files with 162 additions and 64 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `DDPStrategy(start_method=...)` argument, defaulting to 'popen' ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))


- Added checks for whether the iterables used by the loops are valid ([#17007](https://github.com/Lightning-AI/lightning/pull/17007))

### Changed


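As a rough usage sketch (not part of this diff) of the check described in the changelog entry above: passing something that is not iterable where a dataloader is expected should now fail fast with a `TypeError` instead of a confusing error deep inside the loop. The exact message text is an assumption based on `_check_dataloader_iterable` further down.

```python
# Rough usage sketch (assumed behavior, not code from this PR): with the new
# checks, a non-iterable passed in place of a dataloader fails fast.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
try:
    trainer.fit(model, train_dataloaders=123)  # an int is not iterable
except TypeError as err:
    # expected to be along the lines of:
    # "An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found 123."
    print(err)
```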
7 changes: 5 additions & 2 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -29,6 +29,7 @@
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
@@ -123,15 +124,15 @@ def run(self) -> List[_OUT_DICT]:
def setup_data(self) -> None:
trainer = self.trainer
trainer_fn = trainer.state.fn
assert trainer_fn is not None

if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl:
return

source = self._data_source
pl_module = trainer.lightning_module
limit_batches = trainer.limit_test_batches if trainer.testing else trainer.limit_val_batches
hook_name = "test_step" if trainer.testing else "validation_step"
if not source.is_defined() or limit_batches == 0 or not is_overridden(hook_name, pl_module):
if limit_batches == 0 or not is_overridden(hook_name, pl_module):
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
@@ -145,6 +146,7 @@ def setup_data(self) -> None:
stage = trainer.state.stage
assert stage is not None

source = self._data_source
dataloaders = _request_dataloader(source)
trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()")

@@ -163,6 +165,7 @@ def setup_data(self) -> None:
dataloaders = []
self._max_batches = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

14 changes: 10 additions & 4 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -23,14 +23,15 @@
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
_resolve_overfit_batches,
)
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
@@ -208,13 +209,13 @@ def setup_data(self) -> None:
return

trainer = self.trainer
source = self._data_source
pl_module = trainer.lightning_module
if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
return

log.debug(f"{self.__class__.__name__}: resetting train dataloader")

source = self._data_source
train_dataloader = _request_dataloader(source)
trainer.strategy.barrier("train_dataloader()")

@@ -226,7 +227,12 @@
if trainer.overfit_batches > 0:
_resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

dataloaders = [_process_dataloader(trainer, dl) for dl in combined_loader.flattened]
trainer_fn = TrainerFn.FITTING
dataloaders = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

11 changes: 7 additions & 4 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -28,12 +28,13 @@
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
)
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -118,11 +119,11 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:

def setup_data(self) -> None:
trainer = self.trainer
source = self._data_source
# a dfault `predict_step` exists in the LightningModule, so no need to check if it's overridden
if not source.is_defined() or trainer.limit_predict_batches == 0:
# a default `predict_step` exists in the LightningModule, so no need to check if it's overridden
if trainer.limit_predict_batches == 0:
return

source = self._data_source
dataloaders = _request_dataloader(source)
trainer.strategy.barrier("predict_dataloader()")

@@ -135,10 +136,12 @@
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

trainer_fn = TrainerFn.PREDICTING
stage = RunningStage.PREDICTING
dataloaders = []
self.max_batches = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

61 changes: 31 additions & 30 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -118,21 +118,8 @@ def attach_data(
)
self.attach_datamodule(model, datamodule=datamodule)

trainer = self.trainer
fn = trainer.state.fn
# Validate that the required data sources are available
if fn == TrainerFn.FITTING:
_check_dataloader_none(train_dataloaders, trainer.fit_loop._data_source, fn)
# TODO(carmocca): fit's validation dataloaders should be checked too
elif fn == TrainerFn.VALIDATING:
_check_dataloader_none(val_dataloaders, trainer.validate_loop._data_source, fn)
elif fn == TrainerFn.TESTING:
_check_dataloader_none(test_dataloaders, trainer.test_loop._data_source, fn)
elif fn == TrainerFn.PREDICTING:
_check_dataloader_none(predict_dataloaders, trainer.predict_loop._data_source, fn)

# Attach the trainer to the LightningModule
model.trainer = trainer
model.trainer = self.trainer

def attach_dataloaders(
self,
@@ -264,15 +251,18 @@ def _get_distributed_sampler(


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
all_have_sequential_sampler = all(isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened)
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
)
if all_have_sequential_sampler:
return
rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
)
updated = [
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) for dl in combined_loader.flattened
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
for dl in combined_loader.flattened
]
combined_loader.flattened = updated
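
A simplified sketch of the guards added above (names and structure are illustrative; `resolve_overfit` is a made-up stand-in for `_resolve_overfit_batches` and drops most `DataLoader` kwargs): only entries that actually expose `sampler`/`dataset` attributes are rewritten to use a `SequentialSampler`, while bare iterables are passed through untouched.

```python
# Simplified illustration of the hasattr guards above; not the library's actual helper.
from torch.utils.data import DataLoader, SequentialSampler

def resolve_overfit(flattened):
    # Nothing to do if every dataloader that has a sampler already uses a sequential one.
    if all(isinstance(dl.sampler, SequentialSampler) for dl in flattened if hasattr(dl, "sampler")):
        return flattened
    # Rewrap only objects that actually carry a dataset; leave bare iterables alone.
    return [
        DataLoader(dl.dataset, sampler=SequentialSampler(dl.dataset)) if hasattr(dl, "dataset") else dl
        for dl in flattened
    ]

loaders = [DataLoader(range(8), shuffle=True), iter(range(8))]  # the second entry has no sampler/dataset
resolved = resolve_overfit(loaders)
```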

@@ -303,11 +293,9 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
"""
if isinstance(self.instance, pl.LightningModule):
return call._call_lightning_module_hook(self.instance.trainer, self.name, pl_module=self.instance)

if isinstance(self.instance, pl.LightningDataModule):
method = getattr(self.instance, self.name)
return method()

assert self.instance.trainer is not None
return call._call_lightning_datamodule_hook(self.instance.trainer, self.name)
assert self.instance is not None
return self.instance

@@ -386,18 +374,31 @@ def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
return self.model


def _check_dataloader_none(
dataloader: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]],
dataloader_source: _DataLoaderSource,
def _check_dataloader_iterable(
dataloader: object,
source: _DataLoaderSource,
trainer_fn: TrainerFn,
) -> None:
# A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts
prefix = "train_" if trainer_fn == TrainerFn.FITTING else ""
if dataloader is None and not dataloader_source.is_defined():
raise ValueError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement"
f" `def {dataloader_source.name}(self):` in your LightningModule/LightningDataModule."
try:
iter(dataloader) # type: ignore[call-overload]
except TypeError:
# A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts
prefix = "train_" if trainer_fn == TrainerFn.FITTING else ""
if not source.is_module():
raise TypeError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
f" Found {dataloader}."
)
if not is_overridden(source.name, source.instance):
raise TypeError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
f" Found {dataloader}."
f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement"
f" `def {source.name}(self):` in your LightningModule/LightningDataModule."
)
raise TypeError(
f"An invalid dataloader was returned from `{type(source.instance).__name__}.{source.name}()`."
f" Found {dataloader}."
)


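For context, a rough sketch (inferred from the branches of `_check_dataloader_iterable` above, not taken from this PR's tests) of the other failure mode the helper reports separately: a non-iterable returned from an overridden `*_dataloader()` hook.

```python
# Rough sketch (assumed behavior): an overridden `*_dataloader()` hook that
# returns a non-iterable is reported with the module and hook name in the message.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

class BadDataloaderModel(BoringModel):
    def train_dataloader(self):
        return 123  # not an iterable

trainer = Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
try:
    trainer.fit(BadDataloaderModel())
except TypeError as err:
    # expected message along the lines of:
    # "An invalid dataloader was returned from `BadDataloaderModel.train_dataloader()`. Found 123."
    print(err)
```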
93 changes: 71 additions & 22 deletions tests/tests_pytorch/loops/test_loops.py
@@ -762,19 +762,39 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
# `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
# `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

class TestCallback(Callback):
def on_train_epoch_end(self, trainer, *_):
if trainer.current_epoch == 1:
raise CustomException

max_epochs = 3

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=max_epochs,
callbacks=TestCallback() if should_fail else None,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
logger=False,
)

class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
def __init__(self, *args, dataloader, **kwargs):
super().__init__(*args, **kwargs)
self.dataloader = dataloader

def _shutdown_workers(self):
self.dataloader.count_shutdown_workers += 1
self.dataloader.shutdown_workers_epochs.append(trainer.current_epoch)
super()._shutdown_workers()

class TestDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.count_shutdown_workers = 0
self.shutdown_workers_epochs = []

def _get_iterator(self):
if self.num_workers == 0:
@@ -786,29 +806,58 @@ def _get_iterator(self):
train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

class TestCallback(Callback):
def on_train_epoch_end(self, trainer, *_):
if trainer.current_epoch == 1:
raise CustomException

max_epochs = 3

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=max_epochs,
callbacks=TestCallback() if should_fail else None,
)

if should_fail:
with pytest.raises(CustomException):
trainer.fit(model, train_dataloader, val_dataloader)
else:
trainer.fit(model, train_dataloader, val_dataloader)

assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
# on sanity checking end, the workers are being deleted too.
expected = 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
assert val_dataloader.count_shutdown_workers == expected
if persistent_workers:
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
elif should_fail:
expected = [
# iterable check
0,
# epoch ends
1,
# teardown
1,
]
else:
expected = [
# iterable check
0,
# epoch ends
1,
2,
# teardown
3,
]
assert train_dataloader.shutdown_workers_epochs == expected

if persistent_workers:
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
elif should_fail:
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
1,
# teardown
1,
]
else:
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
1,
2,
# teardown
3,
]
assert val_dataloader.shutdown_workers_epochs == expected