Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error checking for non-iterables #17007

Merged
merged 7 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `DDPStrategy(start_method=...)` argument, defaulting to 'popen' ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))


- Added checks for whether the iterables used by the loops are valid ([#17007](https://github.com/Lightning-AI/lightning/pull/17007))

### Changed


Expand Down
7 changes: 5 additions & 2 deletions src/lightning/pytorch/loops/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
Expand Down Expand Up @@ -123,15 +124,15 @@ def run(self) -> List[_OUT_DICT]:
def setup_data(self) -> None:
trainer = self.trainer
trainer_fn = trainer.state.fn
assert trainer_fn is not None

if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl:
return

source = self._data_source
pl_module = trainer.lightning_module
limit_batches = trainer.limit_test_batches if trainer.testing else trainer.limit_val_batches
hook_name = "test_step" if trainer.testing else "validation_step"
if not source.is_defined() or limit_batches == 0 or not is_overridden(hook_name, pl_module):
if limit_batches == 0 or not is_overridden(hook_name, pl_module):
return

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
Expand All @@ -145,6 +146,7 @@ def setup_data(self) -> None:
stage = trainer.state.stage
assert stage is not None

source = self._data_source
dataloaders = _request_dataloader(source)
trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()")

Expand All @@ -163,6 +165,7 @@ def setup_data(self) -> None:
dataloaders = []
self._max_batches = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

Expand Down
14 changes: 10 additions & 4 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
_resolve_overfit_batches,
)
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
Expand Down Expand Up @@ -208,13 +209,13 @@ def setup_data(self) -> None:
return

trainer = self.trainer
source = self._data_source
pl_module = trainer.lightning_module
if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
return

log.debug(f"{self.__class__.__name__}: resetting train dataloader")

source = self._data_source
train_dataloader = _request_dataloader(source)
trainer.strategy.barrier("train_dataloader()")

Expand All @@ -226,7 +227,12 @@ def setup_data(self) -> None:
if trainer.overfit_batches > 0:
_resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

dataloaders = [_process_dataloader(trainer, dl) for dl in combined_loader.flattened]
trainer_fn = TrainerFn.FITTING
dataloaders = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

Expand Down
11 changes: 7 additions & 4 deletions src/lightning/pytorch/loops/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import (
_check_dataloader_iterable,
_DataLoaderSource,
_parse_num_batches,
_process_dataloader,
_request_dataloader,
)
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -118,11 +119,11 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:

def setup_data(self) -> None:
trainer = self.trainer
source = self._data_source
# a dfault `predict_step` exists in the LightningModule, so no need to check if it's overridden
if not source.is_defined() or trainer.limit_predict_batches == 0:
# a default `predict_step` exists in the LightningModule, so no need to check if it's overridden
if trainer.limit_predict_batches == 0:
return

source = self._data_source
dataloaders = _request_dataloader(source)
trainer.strategy.barrier("predict_dataloader()")

Expand All @@ -135,10 +136,12 @@ def setup_data(self) -> None:
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

trainer_fn = TrainerFn.PREDICTING
stage = RunningStage.PREDICTING
dataloaders = []
self.max_batches = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, dl)
dataloaders.append(dl)

Expand Down
60 changes: 30 additions & 30 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,8 @@ def attach_data(
)
self.attach_datamodule(model, datamodule=datamodule)

trainer = self.trainer
fn = trainer.state.fn
# Validate that the required data sources are available
if fn == TrainerFn.FITTING:
_check_dataloader_none(train_dataloaders, trainer.fit_loop._data_source, fn)
# TODO(carmocca): fit's validation dataloaders should be checked too
elif fn == TrainerFn.VALIDATING:
_check_dataloader_none(val_dataloaders, trainer.validate_loop._data_source, fn)
elif fn == TrainerFn.TESTING:
_check_dataloader_none(test_dataloaders, trainer.test_loop._data_source, fn)
elif fn == TrainerFn.PREDICTING:
_check_dataloader_none(predict_dataloaders, trainer.predict_loop._data_source, fn)

# Attach the trainer to the LightningModule
model.trainer = trainer
model.trainer = self.trainer

def attach_dataloaders(
self,
Expand Down Expand Up @@ -264,15 +251,18 @@ def _get_distributed_sampler(


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
all_have_sequential_sampler = all(isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened)
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)
if all_have_sequential_sampler:
return
rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
)
updated = [
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) for dl in combined_loader.flattened
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
for dl in combined_loader.flattened
]
combined_loader.flattened = updated

Expand Down Expand Up @@ -303,11 +293,9 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
"""
if isinstance(self.instance, pl.LightningModule):
return call._call_lightning_module_hook(self.instance.trainer, self.name, pl_module=self.instance)

if isinstance(self.instance, pl.LightningDataModule):
method = getattr(self.instance, self.name)
return method()

assert self.instance.trainer is not None
return call._call_lightning_datamodule_hook(self.instance.trainer, self.name)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
assert self.instance is not None
return self.instance

Expand Down Expand Up @@ -386,18 +374,30 @@ def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.Lightn
return self.model


def _check_dataloader_none(
dataloader: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]],
dataloader_source: _DataLoaderSource,
def _check_dataloader_iterable(
dataloader: object,
source: _DataLoaderSource,
trainer_fn: TrainerFn,
) -> None:
# A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts
prefix = "train_" if trainer_fn == TrainerFn.FITTING else ""
if dataloader is None and not dataloader_source.is_defined():
raise ValueError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement"
f" `def {dataloader_source.name}(self):` in your LightningModule/LightningDataModule."
try:
iter(dataloader) # type: ignore[call-overload]
except TypeError:
# A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts
prefix = "train_" if trainer_fn == TrainerFn.FITTING else ""
if source.is_module():
if not is_overridden(source.name, source.instance):
raise TypeError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
f" Found {dataloader}."
f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement"
f" `def {source.name}(self):` in your LightningModule/LightningDataModule."
)
raise TypeError(
f"An invalid dataloader was returned from `{type(source.instance).__name__}.{source.name}()`."
f" Found {dataloader}."
)
raise TypeError(
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`. Found {dataloader}."
)
carmocca marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
93 changes: 71 additions & 22 deletions tests/tests_pytorch/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,19 +762,39 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
# `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
# `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

class TestCallback(Callback):
def on_train_epoch_end(self, trainer, *_):
if trainer.current_epoch == 1:
raise CustomException

max_epochs = 3

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=max_epochs,
callbacks=TestCallback() if should_fail else None,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
logger=False,
)

class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
def __init__(self, *args, dataloader, **kwargs):
super().__init__(*args, **kwargs)
self.dataloader = dataloader

def _shutdown_workers(self):
self.dataloader.count_shutdown_workers += 1
self.dataloader.shutdown_workers_epochs.append(trainer.current_epoch)
super()._shutdown_workers()

class TestDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.count_shutdown_workers = 0
self.shutdown_workers_epochs = []

def _get_iterator(self):
if self.num_workers == 0:
Expand All @@ -786,29 +806,58 @@ def _get_iterator(self):
train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

class TestCallback(Callback):
def on_train_epoch_end(self, trainer, *_):
if trainer.current_epoch == 1:
raise CustomException

max_epochs = 3

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=max_epochs,
callbacks=TestCallback() if should_fail else None,
)

if should_fail:
with pytest.raises(CustomException):
trainer.fit(model, train_dataloader, val_dataloader)
else:
trainer.fit(model, train_dataloader, val_dataloader)

assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
# on sanity checking end, the workers are being deleted too.
expected = 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
assert val_dataloader.count_shutdown_workers == expected
if persistent_workers:
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
elif should_fail:
expected = [
# iterable check
0,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
# epoch ends
1,
# teardown
1,
]
else:
expected = [
# iterable check
0,
# epoch ends
1,
2,
# teardown
3,
]
assert train_dataloader.shutdown_workers_epochs == expected

if persistent_workers:
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
elif should_fail:
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
1,
# teardown
1,
]
else:
expected = [
# sanity check
0,
# iterable check
0,
# epoch ends
1,
2,
# teardown
3,
]
assert val_dataloader.shutdown_workers_epochs == expected
30 changes: 28 additions & 2 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,34 @@ def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, data
datamodule.test_dataloader = None
datamodule.predict_dataloader = None

with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
with pytest.raises(TypeError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
trainer_fn(model, **{dataloader_name: None}, datamodule=datamodule)

with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
with pytest.raises(TypeError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
trainer_fn(model, **{dataloader_name: None}, datamodule=None)


@pytest.mark.parametrize(
"trainer_fn_name, dataloader_name, stage",
[
("fit", "train_dataloaders", RunningStage.TRAINING),
("validate", "dataloaders", RunningStage.VALIDATING),
("test", "dataloaders", RunningStage.TESTING),
("predict", "dataloaders", RunningStage.PREDICTING),
],
)
@pytest.mark.parametrize("dataloader", [object(), [1, object()]])
carmocca marked this conversation as resolved.
Show resolved Hide resolved
def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage, dataloader):
model = BoringModel()
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1)
trainer_fn = getattr(trainer, trainer_fn_name)

with pytest.raises(
TypeError, match=rf"invalid dataloader was passed to `Trainer.{trainer_fn_name}\({dataloader_name}"
):
trainer_fn(model, **{dataloader_name: dataloader})

dl_method = stage.dataloader_prefix + "_dataloader"
setattr(model, dl_method, lambda: dataloader)
with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"):
trainer_fn(model)