From da276dbd35e0a4dea5fede269c25ae9033792e43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 10 Mar 2023 18:02:28 +0100 Subject: [PATCH] Error checking for non-iterables (#17007) --- src/lightning/pytorch/CHANGELOG.md | 3 + .../pytorch/loops/evaluation_loop.py | 7 +- src/lightning/pytorch/loops/fit_loop.py | 14 ++- .../pytorch/loops/prediction_loop.py | 11 ++- .../trainer/connectors/data_connector.py | 61 ++++++------ tests/tests_pytorch/loops/test_loops.py | 93 ++++++++++++++----- .../trainer/connectors/test_data_connector.py | 37 +++++++- 7 files changed, 162 insertions(+), 64 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 67fb44b41fd65..9ef3daa7456e6 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `DDPStrategy(start_method=...)` argument, defaulting to 'popen' ([#16809](https://github.com/Lightning-AI/lightning/pull/16809)) + +- Added checks for whether the iterables used by the loops are valid ([#17007](https://github.com/Lightning-AI/lightning/pull/17007)) + ### Changed diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 4f0ac6e012fcc..fa78972ee4d0f 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -29,6 +29,7 @@ from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _verify_dataloader_idx_requirement from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.data_connector import ( + _check_dataloader_iterable, _DataLoaderSource, _parse_num_batches, _process_dataloader, @@ -123,15 +124,15 @@ def run(self) -> List[_OUT_DICT]: def setup_data(self) -> None: trainer = self.trainer trainer_fn = trainer.state.fn + assert trainer_fn is not None if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl: return - source = self._data_source pl_module = trainer.lightning_module limit_batches = trainer.limit_test_batches if trainer.testing else trainer.limit_val_batches hook_name = "test_step" if trainer.testing else "validation_step" - if not source.is_defined() or limit_batches == 0 or not is_overridden(hook_name, pl_module): + if limit_batches == 0 or not is_overridden(hook_name, pl_module): return # store epoch of dataloader reset for reload_dataloaders_every_n_epochs @@ -145,6 +146,7 @@ def setup_data(self) -> None: stage = trainer.state.stage assert stage is not None + source = self._data_source dataloaders = _request_dataloader(source) trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()") @@ -163,6 +165,7 @@ def setup_data(self) -> None: dataloaders = [] self._max_batches = [] for dl in combined_loader.flattened: + _check_dataloader_iterable(dl, source, trainer_fn) dl = _process_dataloader(trainer, dl) dataloaders.append(dl) diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index c5778a0d9e818..c214177ebfe28 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -23,6 +23,7 @@ from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.data_connector import ( + _check_dataloader_iterable, _DataLoaderSource, _parse_num_batches, _process_dataloader, @@ -30,7 +31,7 @@ _resolve_overfit_batches, ) from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection -from lightning.pytorch.trainer.states import RunningStage +from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException @@ -208,13 +209,13 @@ def setup_data(self) -> None: return trainer = self.trainer - source = self._data_source pl_module = trainer.lightning_module - if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module): + if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module): return log.debug(f"{self.__class__.__name__}: resetting train dataloader") + source = self._data_source train_dataloader = _request_dataloader(source) trainer.strategy.barrier("train_dataloader()") @@ -226,7 +227,12 @@ def setup_data(self) -> None: if trainer.overfit_batches > 0: _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING) - dataloaders = [_process_dataloader(trainer, dl) for dl in combined_loader.flattened] + trainer_fn = TrainerFn.FITTING + dataloaders = [] + for dl in combined_loader.flattened: + _check_dataloader_iterable(dl, source, trainer_fn) + dl = _process_dataloader(trainer, dl) + dataloaders.append(dl) combined_loader.flattened = dataloaders self._combined_loader = combined_loader diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index cd9817ada6721..ca543cb2e1576 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -28,12 +28,13 @@ from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.trainer import call from lightning.pytorch.trainer.connectors.data_connector import ( + _check_dataloader_iterable, _DataLoaderSource, _parse_num_batches, _process_dataloader, _request_dataloader, ) -from lightning.pytorch.trainer.states import RunningStage +from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -118,11 +119,11 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: def setup_data(self) -> None: trainer = self.trainer - source = self._data_source - # a dfault `predict_step` exists in the LightningModule, so no need to check if it's overridden - if not source.is_defined() or trainer.limit_predict_batches == 0: + # a default `predict_step` exists in the LightningModule, so no need to check if it's overridden + if trainer.limit_predict_batches == 0: return + source = self._data_source dataloaders = _request_dataloader(source) trainer.strategy.barrier("predict_dataloader()") @@ -135,10 +136,12 @@ def setup_data(self) -> None: if trainer.datamodule is not None: allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices + trainer_fn = TrainerFn.PREDICTING stage = RunningStage.PREDICTING dataloaders = [] self.max_batches = [] for dl in combined_loader.flattened: + _check_dataloader_iterable(dl, source, trainer_fn) dl = _process_dataloader(trainer, dl) dataloaders.append(dl) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 7b5f59d08800e..e1f6628509dd8 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -118,21 +118,8 @@ def attach_data( ) self.attach_datamodule(model, datamodule=datamodule) - trainer = self.trainer - fn = trainer.state.fn - # Validate that the required data sources are available - if fn == TrainerFn.FITTING: - _check_dataloader_none(train_dataloaders, trainer.fit_loop._data_source, fn) - # TODO(carmocca): fit's validation dataloaders should be checked too - elif fn == TrainerFn.VALIDATING: - _check_dataloader_none(val_dataloaders, trainer.validate_loop._data_source, fn) - elif fn == TrainerFn.TESTING: - _check_dataloader_none(test_dataloaders, trainer.test_loop._data_source, fn) - elif fn == TrainerFn.PREDICTING: - _check_dataloader_none(predict_dataloaders, trainer.predict_loop._data_source, fn) - # Attach the trainer to the LightningModule - model.trainer = trainer + model.trainer = self.trainer def attach_dataloaders( self, @@ -264,7 +251,9 @@ def _get_distributed_sampler( def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None: - all_have_sequential_sampler = all(isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened) + all_have_sequential_sampler = all( + isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler") + ) if all_have_sequential_sampler: return rank_zero_warn( @@ -272,7 +261,8 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." ) updated = [ - _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) for dl in combined_loader.flattened + _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl + for dl in combined_loader.flattened ] combined_loader.flattened = updated @@ -303,11 +293,9 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """ if isinstance(self.instance, pl.LightningModule): return call._call_lightning_module_hook(self.instance.trainer, self.name, pl_module=self.instance) - if isinstance(self.instance, pl.LightningDataModule): - method = getattr(self.instance, self.name) - return method() - + assert self.instance.trainer is not None + return call._call_lightning_datamodule_hook(self.instance.trainer, self.name) assert self.instance is not None return self.instance @@ -386,18 +374,31 @@ def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.Lightn return self.model -def _check_dataloader_none( - dataloader: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]], - dataloader_source: _DataLoaderSource, +def _check_dataloader_iterable( + dataloader: object, + source: _DataLoaderSource, trainer_fn: TrainerFn, ) -> None: - # A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts - prefix = "train_" if trainer_fn == TrainerFn.FITTING else "" - if dataloader is None and not dataloader_source.is_defined(): - raise ValueError( - f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`." - f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement" - f" `def {dataloader_source.name}(self):` in your LightningModule/LightningDataModule." + try: + iter(dataloader) # type: ignore[call-overload] + except TypeError: + # A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts + prefix = "train_" if trainer_fn == TrainerFn.FITTING else "" + if not source.is_module(): + raise TypeError( + f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`." + f" Found {dataloader}." + ) + if not is_overridden(source.name, source.instance): + raise TypeError( + f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`." + f" Found {dataloader}." + f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement" + f" `def {source.name}(self):` in your LightningModule/LightningDataModule." + ) + raise TypeError( + f"An invalid dataloader was returned from `{type(source.instance).__name__}.{source.name}()`." + f" Found {dataloader}." ) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index e0a154255b61f..d6076347bbdc9 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -762,19 +762,39 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + class TestCallback(Callback): + def on_train_epoch_end(self, trainer, *_): + if trainer.current_epoch == 1: + raise CustomException + + max_epochs = 3 + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=max_epochs, + callbacks=TestCallback() if should_fail else None, + enable_checkpointing=False, + enable_model_summary=False, + enable_progress_bar=False, + logger=False, + ) + class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): def __init__(self, *args, dataloader, **kwargs): super().__init__(*args, **kwargs) self.dataloader = dataloader def _shutdown_workers(self): - self.dataloader.count_shutdown_workers += 1 + self.dataloader.shutdown_workers_epochs.append(trainer.current_epoch) super()._shutdown_workers() class TestDataLoader(DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.count_shutdown_workers = 0 + self.shutdown_workers_epochs = [] def _get_iterator(self): if self.num_workers == 0: @@ -786,29 +806,58 @@ def _get_iterator(self): train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - class TestCallback(Callback): - def on_train_epoch_end(self, trainer, *_): - if trainer.current_epoch == 1: - raise CustomException - - max_epochs = 3 - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=max_epochs, - callbacks=TestCallback() if should_fail else None, - ) - if should_fail: with pytest.raises(CustomException): trainer.fit(model, train_dataloader, val_dataloader) else: trainer.fit(model, train_dataloader, val_dataloader) - assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs) - # on sanity checking end, the workers are being deleted too. - expected = 2 if persistent_workers else (3 if should_fail else max_epochs + 1) - assert val_dataloader.count_shutdown_workers == expected + if persistent_workers: + expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown + elif should_fail: + expected = [ + # iterable check + 0, + # epoch ends + 1, + # teardown + 1, + ] + else: + expected = [ + # iterable check + 0, + # epoch ends + 1, + 2, + # teardown + 3, + ] + assert train_dataloader.shutdown_workers_epochs == expected + + if persistent_workers: + expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown + elif should_fail: + expected = [ + # sanity check + 0, + # iterable check + 0, + # epoch ends + 1, + # teardown + 1, + ] + else: + expected = [ + # sanity check + 0, + # iterable check + 0, + # epoch ends + 1, + 2, + # teardown + 3, + ] + assert val_dataloader.shutdown_workers_epochs == expected diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index fd1acbfbabc70..03ad099ec451f 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -614,8 +614,41 @@ def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, data datamodule.test_dataloader = None datamodule.predict_dataloader = None - with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"): + with pytest.raises(TypeError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"): trainer_fn(model, **{dataloader_name: None}, datamodule=datamodule) - with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"): + with pytest.raises(TypeError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"): trainer_fn(model, **{dataloader_name: None}, datamodule=None) + + +@pytest.mark.parametrize( + "trainer_fn_name, dataloader_name, stage", + [ + ("fit", "train_dataloaders", RunningStage.TRAINING), + ("validate", "dataloaders", RunningStage.VALIDATING), + ("test", "dataloaders", RunningStage.TESTING), + ("predict", "dataloaders", RunningStage.PREDICTING), + ], +) +@pytest.mark.parametrize("dataloader", [None, object(), [1, object()]]) +def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage, dataloader): + model = BoringModel() + + # Pretend that these methods are not implemented + model.train_dataloader = None + model.val_dataloader = None + model.test_dataloader = None + model.predict_dataloader = None + + trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1) + trainer_fn = getattr(trainer, trainer_fn_name) + + with pytest.raises( + TypeError, match=rf"invalid dataloader was passed to `Trainer.{trainer_fn_name}\({dataloader_name}" + ): + trainer_fn(model, **{dataloader_name: dataloader}) + + dl_method = stage.dataloader_prefix + "_dataloader" + setattr(model, dl_method, lambda: dataloader) + with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"): + trainer_fn(model)