From 874614fa6b7fea4608db4e698d14198e7556362c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Sep 2021 12:59:07 +0200 Subject: [PATCH 01/50] draft --- .../trainer/configuration_validator.py | 4 +- .../trainer/connectors/data_connector.py | 106 ++++++++++++++---- pytorch_lightning/trainer/data_loading.py | 6 +- tests/trainer/test_data_loading.py | 8 +- tests/trainer/test_dataloaders.py | 5 +- 5 files changed, 98 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index ee5be467b86bf..56732283c9b34 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -63,7 +63,7 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None # ----------------------------------- # verify model has a train dataloader # ----------------------------------- - has_train_dataloader = is_overridden("train_dataloader", model) + has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_available() if not has_train_dataloader: raise MisconfigurationException( "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" @@ -159,7 +159,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s ) def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None: - has_predict_dataloader = is_overridden("predict_dataloader", model) + has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_available() if not has_predict_dataloader: raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 589906d2bb0e4..a276c3bd37cf2 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from dataclasses import dataclass from functools import partial -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Iterable, Optional, Union, List, Dict + +from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation @@ -35,6 +38,7 @@ def __init__( self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", + # TODO: remove these args train_data_fetcher: Optional[AbstractDataFetcher] = None, validate_data_fetcher: Optional[AbstractDataFetcher] = None, test_data_fetcher: Optional[AbstractDataFetcher] = None, @@ -47,6 +51,15 @@ def __init__( self.test_data_fetcher = test_data_fetcher self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None + self._train_dataloader_source = DataLoaderSource() + self._val_dataloader_source = DataLoaderSource() + self._test_dataloader_source = DataLoaderSource() + self._predict_dataloader_source = DataLoaderSource() + + # @property + # def train_dataloader(self): + # pass + @property def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: if self.trainer.sanity_checking: @@ -190,27 +203,33 @@ def attach_dataloaders( test_dataloaders: Optional[EVAL_DATALOADERS] = None, predict_dataloaders: Optional[EVAL_DATALOADERS] = None, ) -> None: - # when dataloader is passed via fit, patch the train_dataloader - # functions to overwrite with these implementations - if train_dataloaders is not None: - self.trainer.train_dataloader = None - train_dataloader = _PatchDataLoader(train_dataloaders, "train") - train_dataloader.patch(model) - - if val_dataloaders is not None: - self.trainer.val_dataloaders = None - val_dataloader = _PatchDataLoader(val_dataloaders, "val") - val_dataloader.patch(model) - - if test_dataloaders is not None: - self.trainer.test_dataloaders = None - test_dataloader = _PatchDataLoader(test_dataloaders, "test") - test_dataloader.patch(model) - - if predict_dataloaders is not None: - self.trainer.predict_dataloaders = None - predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict") - predict_dataloader.patch(model) + self._train_dataloader_source = DataLoaderSource( + train_dataloaders if train_dataloaders is not None else model, "train_dataloader" + ) + self._val_dataloader_source = DataLoaderSource( + val_dataloaders if val_dataloaders is not None else model, "val_dataloader" + ) + self._test_dataloader_source = DataLoaderSource( + test_dataloaders if test_dataloaders is not None else model, "test_dataloader" + ) + self._predict_dataloader_source = DataLoaderSource( + predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader" + ) + + # if val_dataloaders is not None: + # self.trainer.val_dataloaders = val_dataloaders + # val_dataloader = _PatchDataLoader(val_dataloaders, "val") + # val_dataloader.patch(model) + + # if test_dataloaders is not None: + # self.trainer.test_dataloaders = test_dataloaders + # test_dataloader = _PatchDataLoader(test_dataloaders, "test") + # test_dataloader.patch(model) + + # if predict_dataloaders is not None: + # self.trainer.predict_dataloaders = predict_dataloaders + # predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict") + # predict_dataloader.patch(model) def attach_datamodule( self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None @@ -259,6 +278,49 @@ def teardown(self) -> None: self.sanity_check_data_fetcher.teardown() self.sanity_check_data_fetcher = None + # 
self._train_dataloader_source = DataLoaderSource() + # self._val_dataloader_source = DataLoaderSource() + # self._test_dataloader_source = DataLoaderSource() + # self._predict_dataloader_source = DataLoaderSource() + + +# TODO: type for list/dict of dataloaders +@dataclass +class DataLoaderSource_old: + source: Union[DataLoader, Callable[[], DataLoader]] = None + + def request(self) -> Union[DataLoader]: + return self.source() if callable(self.source) else self.source + + # TODO: necessary? + def __bool__(self): + return self.source is not None + + +_DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader]] + +# TODO: type for list/dict of dataloaders +@dataclass +class DataLoaderSource: + + instance: Optional[Union[_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] = None + name: str = "" + + def request(self) -> _DATALOADERS: + from pytorch_lightning import LightningDataModule, LightningModule + + if isinstance(self.instance, (LightningModule, LightningDataModule)) and self.name: + return getattr(self.instance, self.name)() + return self.instance + + # TODO: needed in config validator? + def is_available(self) -> bool: + from pytorch_lightning import LightningDataModule, LightningModule + + return not isinstance(self.instance, (LightningModule, LightningDataModule)) or is_overridden( + self.name, self.instance + ) + class _PatchDataLoader: r""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index fefac298887ea..1799bd2b46c09 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -537,6 +537,7 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No if self.val_dataloaders is None: self.reset_val_dataloader(model=model) + # FIXME: wrong docstring def request_dataloader( self, stage: RunningStage, model: Optional["pl.LightningModule"] = None ) -> Union[DataLoader, List[DataLoader]]: @@ -545,9 +546,12 @@ def request_dataloader( Returns: The dataloader """ + source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source") + hook = f"{stage.dataloader_prefix}_dataloader" self.call_hook("on_" + hook, pl_module=model) - dataloader = self.call_hook(hook, pl_module=model) + # dataloader = self.call_hook(hook, pl_module=model) + dataloader = source.request() if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 6e91cf926723c..31e18b2bfb578 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -267,19 +267,19 @@ def test_loader_detaching(): class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): - assert len(model.train_dataloader()) == 10 + assert len(self.trainer.train_dataloader.loaders) == 10 return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - assert len(model.val_dataloader()) == 10 + assert len(self.trainer.val_dataloaders[0]) == 10 return super().validation_step(batch, batch_idx) def test_step(self, batch, batch_idx): - assert len(model.test_dataloader()) == 10 + assert len(self.trainer.test_dataloaders[0]) == 10 return super().test_step(batch, batch_idx) def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert len(model.predict_dataloader()) == 10 + assert len(self.trainer.predict_dataloaders[0]) == 10 return super().predict_step(batch, batch_idx, 
dataloader_idx=dataloader_idx) loader = DataLoader(RandomDataset(32, 10), batch_size=1) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f5709836c0db6..5cb78252b11d8 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders model.test_step = model.test_step__multiple_dataloaders - # train, multiple val and multiple test passed to fit + # multiple val dataloaders passed to fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) @@ -1313,7 +1313,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): def test_dataloaders_reset_and_attach(tmpdir): - """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before + """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching the new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) @@ -1486,6 +1486,7 @@ def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> No assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) self.on_train_batch_start_called = True + # FIXME: this patching happens after we set the source, so we still call the old method def on_val_dataloader(self) -> None: loader = self.val_dataloader() self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) From 6d6b3eab8ccd617a0ff6f2902e8e74460238c969 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Sep 2021 13:31:32 +0200 Subject: [PATCH 02/50] draft --- .../plugins/training_type/tpu_spawn.py | 1 + .../trainer/connectors/data_connector.py | 87 ++----------------- pytorch_lightning/trainer/data_loading.py | 4 +- pytorch_lightning/trainer/trainer.py | 2 - tests/callbacks/test_stochastic_weight_avg.py | 2 - tests/utilities/test_model_helpers.py | 1 + 6 files changed, 10 insertions(+), 87 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 43831fa2ac908..b08aa80eb9bd1 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -88,6 +88,7 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." ) + # TODO: what to do here? 
@staticmethod def _validate_patched_dataloaders(model: Module) -> None: """Validate and fail fast if the dataloaders were passed directly to fit.""" diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index a276c3bd37cf2..9dd9c80e2117b 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass from functools import partial -from typing import Callable, Iterable, Optional, Union, List, Dict +from typing import Iterable, Optional, Union, List, Dict from torch.utils.data import DataLoader @@ -56,10 +56,6 @@ def __init__( self._test_dataloader_source = DataLoaderSource() self._predict_dataloader_source = DataLoaderSource() - # @property - # def train_dataloader(self): - # pass - @property def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: if self.trainer.sanity_checking: @@ -216,21 +212,6 @@ def attach_dataloaders( predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader" ) - # if val_dataloaders is not None: - # self.trainer.val_dataloaders = val_dataloaders - # val_dataloader = _PatchDataLoader(val_dataloaders, "val") - # val_dataloader.patch(model) - - # if test_dataloaders is not None: - # self.trainer.test_dataloaders = test_dataloaders - # test_dataloader = _PatchDataLoader(test_dataloaders, "test") - # test_dataloader.patch(model) - - # if predict_dataloaders is not None: - # self.trainer.predict_dataloaders = predict_dataloaders - # predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict") - # predict_dataloader.patch(model) - def attach_datamodule( self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None ) -> None: @@ -238,11 +219,10 @@ def attach_datamodule( if datamodule is None: return - # Override loader hooks - dl_methods = ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader") - for method in dl_methods: - if is_overridden(method, datamodule): - setattr(model, method, getattr(datamodule, method)) + self._train_dataloader_source = DataLoaderSource(datamodule, "train_dataloader") + self._val_dataloader_source = DataLoaderSource(datamodule, "val_dataloader") + self._test_dataloader_source = DataLoaderSource(datamodule, "test_dataloader") + self._predict_dataloader_source = DataLoaderSource(datamodule, "predict_dataloader") # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") @@ -257,13 +237,6 @@ def attach_datamodule( if hasattr(datamodule, "data_pipeline"): model.data_pipeline = datamodule.data_pipeline - @staticmethod - def detach_data(model: "pl.LightningModule") -> None: - for stage in ("train", "val", "test", "predict"): - loader = getattr(model, f"{stage}_dataloader", None) - if isinstance(loader, _PatchDataLoader): - loader.unpatch(model) - def teardown(self) -> None: if self.train_data_fetcher: self.train_data_fetcher.teardown() @@ -278,27 +251,10 @@ def teardown(self) -> None: self.sanity_check_data_fetcher.teardown() self.sanity_check_data_fetcher = None - # self._train_dataloader_source = DataLoaderSource() - # self._val_dataloader_source = DataLoaderSource() - # self._test_dataloader_source = DataLoaderSource() - # self._predict_dataloader_source = DataLoaderSource() - - -# TODO: type for list/dict of 
dataloaders -@dataclass -class DataLoaderSource_old: - source: Union[DataLoader, Callable[[], DataLoader]] = None - - def request(self) -> Union[DataLoader]: - return self.source() if callable(self.source) else self.source - - # TODO: necessary? - def __bool__(self): - return self.source is not None - _DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader]] + # TODO: type for list/dict of dataloaders @dataclass class DataLoaderSource: @@ -320,34 +276,3 @@ def is_available(self) -> bool: return not isinstance(self.instance, (LightningModule, LightningDataModule)) or is_overridden( self.name, self.instance ) - - -class _PatchDataLoader: - r""" - Callable object for patching dataloaders passed into trainer.fit(). - Use this class to override model.*_dataloader() and be pickle-compatible. - - Args: - dataloader: Dataloader object to return when called. - """ - - def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None: - self.dataloader = dataloader - - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - self.patch_loader_code = str(self.__call__.__code__) - self._old_loader: Optional[Callable] = None - self.stage = stage - - def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: - return self.dataloader - - def patch(self, model: "pl.LightningModule") -> None: - self._old_loader = getattr(model, self.stage + "_dataloader") - setattr(model, self.stage + "_dataloader", self) - - def unpatch(self, model: "pl.LightningModule") -> None: - setattr(model, self.stage + "_dataloader", self._old_loader) - self._old_loader = None diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 1799bd2b46c09..00feb4618d8e1 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -516,9 +516,9 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) Args: model: The `LightningModule` if called outside of the trainer scope. 
""" + source = self.data_connector._predict_dataloader_source pl_module = self.lightning_module or model - has_loader = is_overridden("predict_dataloader", pl_module) - if has_loader: + if source.is_available: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( RunningStage.PREDICTING, model=pl_module ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f49c892e37191..f3e9bb43efc38 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1306,8 +1306,6 @@ def _call_teardown_hook(self) -> None: if self.datamodule is not None: self.datamodule.teardown(stage=fn) - self.data_connector.detach_data(self.lightning_module) - self.call_hook("teardown", stage=fn) self.lightning_module._current_fx_name = None diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index b2693ed5ded48..044b367699818 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -24,7 +24,6 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -228,7 +227,6 @@ def on_before_accelerator_backend_setup(self, trainer: "Trainer", pl_module: "Li super().on_before_accelerator_backend_setup(trainer, pl_module) assert self._average_model.train_dataloader is not pl_module.train_dataloader assert self._average_model.train_dataloader.__self__ == self._average_model - assert isinstance(pl_module.train_dataloader, _PatchDataLoader) assert self._average_model.trainer is None self.on_before_accelerator_backend_setup_called = True diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py index 1319e6b44fd8f..0594574c2de9a 100644 --- a/tests/utilities/test_model_helpers.py +++ b/tests/utilities/test_model_helpers.py @@ -78,6 +78,7 @@ def bar(self): model.training_step = partial(model.training_step) assert is_overridden("training_step", model) + # TODO: remove # `_PatchDataLoader.patch_loader_code` support class TestModel(BoringModel): def on_fit_start(self): From a10c526012d6c66e8486a3d050518e0ef7e9c806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Sep 2021 13:34:14 +0200 Subject: [PATCH 03/50] clean up --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 - tests/utilities/test_model_helpers.py | 12 ------------ 2 files changed, 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index b08aa80eb9bd1..2959063b8a92c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -27,7 +27,6 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin -from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, 
_TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py index 0594574c2de9a..a1d5277697bef 100644 --- a/tests/utilities/test_model_helpers.py +++ b/tests/utilities/test_model_helpers.py @@ -77,15 +77,3 @@ def bar(self): # `partial` support model.training_step = partial(model.training_step) assert is_overridden("training_step", model) - - # TODO: remove - # `_PatchDataLoader.patch_loader_code` support - class TestModel(BoringModel): - def on_fit_start(self): - assert is_overridden("train_dataloader", self) - self.on_fit_start_called = True - - model = TestModel() - trainer = Trainer(fast_dev_run=1) - trainer.fit(model, train_dataloader=model.train_dataloader()) - assert model.on_fit_start_called From a1c5537d0b44f2fa5f41c7b5aed4dee4bfea4d24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Sep 2021 15:54:20 +0000 Subject: [PATCH 04/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- tests/trainer/test_dataloaders.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9dd9c80e2117b..1557652c2e147 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass from functools import partial -from typing import Iterable, Optional, Union, List, Dict +from typing import Dict, Iterable, List, Optional, Union from torch.utils.data import DataLoader diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 5cb78252b11d8..86999dc99ebe8 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1313,8 +1313,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): def test_dataloaders_reset_and_attach(tmpdir): - """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before - attaching the new one.""" + """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching + the new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_1 = DataLoader(dataset=RandomDataset(32, 64)) From 5c36fbcbce43813f52c3580fd74caed411587fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Sep 2021 18:14:10 +0200 Subject: [PATCH 05/50] check availability of val/test dataloader --- pytorch_lightning/trainer/data_loading.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 00feb4618d8e1..c1d4f88fc52b6 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -488,10 +488,10 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> Args: model: The `LightningModule` if called outside of the trainer scope. 
""" + source = self.data_connector._val_dataloader_source pl_module = self.lightning_module or model - has_loader = is_overridden("val_dataloader", pl_module) has_step = is_overridden("validation_step", pl_module) - if has_loader and has_step: + if source.is_available and has_step: self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( RunningStage.VALIDATING, model=pl_module ) @@ -502,10 +502,10 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> Args: model: The `LightningModule` if called outside of the trainer scope. """ + source = self.data_connector._test_dataloader_source pl_module = self.lightning_module or model - has_loader = is_overridden("test_dataloader", pl_module) has_step = is_overridden("test_step", pl_module) - if has_loader and has_step: + if source.is_available and has_step: self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( RunningStage.TESTING, model=pl_module ) From 84086e634a33db52590cb1b69b7243758b056872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Sep 2021 18:33:18 +0200 Subject: [PATCH 06/50] availability check / property --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- pytorch_lightning/trainer/data_loading.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 56732283c9b34..807758acd8d2f 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -63,7 +63,7 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None # ----------------------------------- # verify model has a train dataloader # ----------------------------------- - has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_available() + has_train_dataloader = self.trainer.data_connector._train_dataloader_source.available if not has_train_dataloader: raise MisconfigurationException( "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" @@ -159,7 +159,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s ) def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None: - has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_available() + has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.available if not has_predict_dataloader: raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 1557652c2e147..ee7e484f8e3da 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -269,8 +269,8 @@ def request(self) -> _DATALOADERS: return getattr(self.instance, self.name)() return self.instance - # TODO: needed in config validator? 
- def is_available(self) -> bool: + @property + def available(self) -> bool: from pytorch_lightning import LightningDataModule, LightningModule return not isinstance(self.instance, (LightningModule, LightningDataModule)) or is_overridden( diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index c1d4f88fc52b6..4f99eb1b4c5ec 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -491,7 +491,7 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._val_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("validation_step", pl_module) - if source.is_available and has_step: + if source.available and has_step: self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( RunningStage.VALIDATING, model=pl_module ) @@ -505,7 +505,7 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._test_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("test_step", pl_module) - if source.is_available and has_step: + if source.available and has_step: self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( RunningStage.TESTING, model=pl_module ) @@ -518,7 +518,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) """ source = self.data_connector._predict_dataloader_source pl_module = self.lightning_module or model - if source.is_available: + if source.available: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( RunningStage.PREDICTING, model=pl_module ) From 5d93ee4b939ebe3e965acea3f203e5b412c9d959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:14:56 +0200 Subject: [PATCH 07/50] hack around on tpu test --- .../plugins/training_type/tpu_spawn.py | 20 +++++++++---------- tests/plugins/test_tpu_spawn.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2959063b8a92c..21520110036e7 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -87,21 +87,21 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." ) - # TODO: what to do here? 
@staticmethod - def _validate_patched_dataloaders(model: Module) -> None: + def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: """Validate and fail fast if the dataloaders were passed directly to fit.""" - if hasattr(model, "train_dataloader") and isinstance(model.train_dataloader, _PatchDataLoader): - TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader) + connector = model.trainer.data_connector + if connector._train_dataloader_source.instance is not model: + TPUSpawnPlugin._validate_dataloader(connector._train_dataloader_source.instance) - if hasattr(model, "val_dataloader") and isinstance(model.val_dataloader, _PatchDataLoader): - TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader) + if connector._val_dataloader_source.instance is not model: + TPUSpawnPlugin._validate_dataloader(connector._val_dataloader_source.instance) - if hasattr(model, "test_dataloader") and isinstance(model.test_dataloader, _PatchDataLoader): - TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader) + if connector._test_dataloader_source.instance is not model: + TPUSpawnPlugin._validate_dataloader(connector._test_dataloader_source.instance) - if hasattr(model, "predict_dataloader") and isinstance(model.predict_dataloader, _PatchDataLoader): - TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader) + if connector._predict_dataloader_source.instance is not model: + TPUSpawnPlugin._validate_dataloader(connector._predict_dataloader_source.instance) def connect(self, model: "pl.LightningModule") -> None: TPUSpawnPlugin._validate_patched_dataloaders(model) diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 036e26a7c4a2f..8b04492d149f1 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -61,10 +61,11 @@ def predict_dataloader(self): def test_error_patched_iterable_dataloaders( _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders ): + trainer = Trainer() model = BoringModelNoDataloaders() - connector = DataConnector(MagicMock()) + model.trainer = trainer - connector.attach_dataloaders( + trainer.data_connector.attach_dataloaders( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, From 7b19d5ecccf5876080116053db304c05fb598700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:18:27 +0200 Subject: [PATCH 08/50] fix test_dataloaders_reset_and_attach test --- pytorch_lightning/trainer/connectors/data_connector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index ee7e484f8e3da..4171dc265c1fc 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -199,6 +199,11 @@ def attach_dataloaders( test_dataloaders: Optional[EVAL_DATALOADERS] = None, predict_dataloaders: Optional[EVAL_DATALOADERS] = None, ) -> None: + self.trainer.train_dataloader = None + self.trainer.val_dataloaders = None + self.trainer.test_dataloaders = None + self.trainer.predict_dataloaders = None + self._train_dataloader_source = DataLoaderSource( train_dataloaders if train_dataloaders is not None else model, "train_dataloader" ) From 1fc7c2a5c458d250b6a4dfb4bb54568bbc4f52e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:25:03 +0200 Subject: [PATCH 09/50] wip --- 
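Reviewer note: `Trainer.tune` now receives the loader through the renamed
`train_dataloaders` argument, so it is recorded by `attach_dataloaders` as a
`DataLoaderSource` rather than patched onto the model; the `TODO` below is resolved
in the next commit by pinning the expected error message. For orientation, the guard
this series converges on a few commits later is roughly the following sketch
(illustration only, assuming a `trainer` in scope):

    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # A loader passed directly to fit()/tune() is stored as the source's
    # `instance` itself, so the tuner cannot re-create it after mutating
    # `model.batch_size`; batch-size scaling has to be rejected in that case.
    source = trainer.data_connector._train_dataloader_source
    if not source.is_module():
        raise MisconfigurationException(
            "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
            " Please disable the feature or incorporate the dataloader into the model."
        )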
tests/tuner/test_scale_batch_size.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 1d5c7ae2257d3..f473ea5feb892 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -217,8 +217,9 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): limit_train_batches=0.2, auto_scale_batch_size="power", ) - fit_options = dict(train_dataloader=model.dataloader(train=True)) + fit_options = dict(train_dataloaders=model.dataloader(train=True)) + # TODO: specify error message with pytest.raises(MisconfigurationException): trainer.tune(model, **fit_options) From 3e05cfc719bdd0c5dfeb225618493472e746aa3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:37:54 +0200 Subject: [PATCH 10/50] specify error message in test --- tests/tuner/test_scale_batch_size.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index f473ea5feb892..0f900dd672b40 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -219,8 +219,10 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): ) fit_options = dict(train_dataloaders=model.dataloader(train=True)) - # TODO: specify error message - with pytest.raises(MisconfigurationException): + with pytest.raises( + MisconfigurationException, + match="The batch scaling feature cannot be used with dataloaders passed directly", + ): trainer.tune(model, **fit_options) From 35bb187d7c7ed17bfe7f8ce7b78328f13194bbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:38:16 +0200 Subject: [PATCH 11/50] fix scale batch size test --- pytorch_lightning/trainer/connectors/data_connector.py | 1 + pytorch_lightning/tuner/batch_size_scaling.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 4171dc265c1fc..400133be6a782 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -274,6 +274,7 @@ def request(self) -> _DATALOADERS: return getattr(self.instance, self.name)() return self.instance + # TODO: move is_overridden check back to config validator? @property def available(self) -> bool: from pytorch_lightning import LightningDataModule, LightningModule diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 0ecc983994afd..82ffb710bf753 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -51,7 +51,7 @@ def scale_batch_size( " If this is not the intended behavior, please remove either one." ) - if hasattr(model.train_dataloader, "patch_loader_code"): + if trainer.data_connector._train_dataloader_source.available: raise MisconfigurationException( "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`." " Please disable the feature or incorporate the dataloader into the model." 
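[Series note, illustration only] From this point on, every dataloader reaches the
trainer through a `DataLoaderSource`: `attach_dataloaders` and `attach_datamodule`
record where the loader comes from, and `request_dataloader` resolves it lazily via
`request()`. A hedged usage sketch, reusing the test helpers imported elsewhere in
this series (`BoringModel` defines its own `train_dataloader()`):

    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel, RandomDataset

    loader = DataLoader(RandomDataset(32, 64))
    model = BoringModel()
    trainer = Trainer()

    # passed directly to fit(): the source stores the loader and returns it as-is
    trainer.data_connector.attach_dataloaders(model, train_dataloaders=loader)
    assert trainer.data_connector._train_dataloader_source.request() is loader

    # nothing passed: the source stores the module and calls its hook on request()
    trainer.data_connector.attach_dataloaders(model)
    train_dl = trainer.data_connector._train_dataloader_source.request()  # model.train_dataloader()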
From 37dddfc08900e9e6d544560a499133a7631eb247 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:38:39 +0200 Subject: [PATCH 12/50] remove patch_loader_code check from is_overridden util --- pytorch_lightning/utilities/model_helpers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 8596f1c67b812..240470645c481 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -64,10 +64,7 @@ def is_overridden( if parent_attr is None: raise ValueError("The parent should define the method") - # cannot pickle `__code__` so cannot verify if `PatchDataloader` - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - instance_code = getattr(instance_attr, "patch_loader_code", None) or instance_attr.__code__ + instance_code = instance_attr.__code__ parent_code = parent_attr.__code__ return instance_code != parent_code From e8b8dcb8c8fa2a5648cbbc13340b757e4e0680b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 03:40:17 +0200 Subject: [PATCH 13/50] remove patch_loader_code reference from plugins registry --- pytorch_lightning/plugins/plugins_registry.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 1040a35eaa369..87a137aaac8a9 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -127,11 +127,7 @@ def is_register_plugins_overridden(plugin: type) -> bool: else: return False - if hasattr(plugin_attr, "patch_loader_code"): - is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) - else: - is_overridden = plugin_attr.__code__ is not super_attr.__code__ - return is_overridden + return plugin_attr.__code__ is not super_attr.__code__ def call_training_type_register_plugins(root: Path, base_module: str) -> None: From 91c24dbf8f49bd819b294a750fe2399645e82e31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 10:23:00 +0200 Subject: [PATCH 14/50] add is_module method --- .../trainer/connectors/data_connector.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 400133be6a782..0b81cbd37a74e 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -267,18 +267,16 @@ class DataLoaderSource: instance: Optional[Union[_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] = None name: str = "" - def request(self) -> _DATALOADERS: - from pytorch_lightning import LightningDataModule, LightningModule + @property + def available(self) -> bool: + return not self.is_module() or is_overridden(self.name, self.instance) - if isinstance(self.instance, (LightningModule, LightningDataModule)) and self.name: + def request(self) -> _DATALOADERS: + if self.is_module() and self.name: return getattr(self.instance, self.name)() return self.instance - # TODO: move is_overridden check back to config validator? 
- @property - def available(self) -> bool: + def is_module(self) -> bool: from pytorch_lightning import LightningDataModule, LightningModule - return not isinstance(self.instance, (LightningModule, LightningDataModule)) or is_overridden( - self.name, self.instance - ) + return isinstance(self.instance, (LightningModule, LightningDataModule)) From 5af46cec9365357266ae6ad34e8af5635bc8b351 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 30 Sep 2021 10:23:15 +0200 Subject: [PATCH 15/50] update tests for is_module() check --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 11 ++++++----- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- tests/plugins/test_tpu_spawn.py | 3 ++- tests/trainer/test_dataloaders.py | 10 ++++------ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 21520110036e7..acff554d99499 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -90,17 +91,17 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No @staticmethod def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: """Validate and fail fast if the dataloaders were passed directly to fit.""" - connector = model.trainer.data_connector - if connector._train_dataloader_source.instance is not model: + connector: DataConnector = model.trainer.data_connector + if not connector._train_dataloader_source.is_module(): TPUSpawnPlugin._validate_dataloader(connector._train_dataloader_source.instance) - if connector._val_dataloader_source.instance is not model: + if not connector._val_dataloader_source.is_module(): TPUSpawnPlugin._validate_dataloader(connector._val_dataloader_source.instance) - if connector._test_dataloader_source.instance is not model: + if not connector._test_dataloader_source.is_module(): TPUSpawnPlugin._validate_dataloader(connector._test_dataloader_source.instance) - if connector._predict_dataloader_source.instance is not model: + if not connector._predict_dataloader_source.is_module(): TPUSpawnPlugin._validate_dataloader(connector._predict_dataloader_source.instance) def connect(self, model: "pl.LightningModule") -> None: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 82ffb710bf753..a8a7d1ef09cc4 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -51,7 +51,7 @@ def scale_batch_size( " If this is not the intended behavior, please remove either one." ) - if trainer.data_connector._train_dataloader_source.available: + if not trainer.data_connector._train_dataloader_source.is_module(): raise MisconfigurationException( "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`." " Please disable the feature or incorporate the dataloader into the model." 
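[Illustration only, not part of the diff] With `is_module()` from the previous
commit, both call sites above (the TPU validation and the tuner guard) ask the same
question: can the loader be re-created by calling a hook on a LightningModule or
LightningDataModule? For a loader handed straight to `fit()` the answer is no:

    from torch.utils.data import DataLoader
    from pytorch_lightning.trainer.connectors.data_connector import DataLoaderSource
    from tests.helpers.boring_model import BoringModel, RandomDataset

    DataLoaderSource(DataLoader(RandomDataset(32, 64)), "train_dataloader").is_module()  # False: raw loader
    DataLoaderSource(BoringModel(), "train_dataloader").is_module()  # True: module-backed hook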
diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 8b04492d149f1..deaef3d7c5f60 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -58,9 +58,10 @@ def predict_dataloader(self): ], ) @mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") -def test_error_patched_iterable_dataloaders( +def test_error_iterable_dataloaders_passed_to_fit( _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders ): + """Test that the TPUSpawnPlugin identifies dataloaders with iterable datasets and fails early.""" trainer = Trainer() model = BoringModelNoDataloaders() model.trainer = trainer diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 86999dc99ebe8..f3bb19891109a 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -195,10 +195,10 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) - trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) + assert len(trainer.test_dataloaders) == n + trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) assert len(trainer.val_dataloaders) == n - assert len(trainer.test_dataloaders) == n class DummyModel(BoringModel): @@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + if ckpt_path == "specific": ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) - - assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" assert ( len(trainer.test_dataloaders) == 1 ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" From a8d4a26545a8145638db3b7b121ab204b490e9c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 15:29:45 +0200 Subject: [PATCH 16/50] update tests --- tests/callbacks/test_early_stopping.py | 2 +- tests/models/test_restore.py | 2 +- tests/trainer/test_dataloaders.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index fe6873c8f43bf..7fd1c068df3e1 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): ) with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"): - new_trainer.fit(model) + new_trainer.fit(model, datamodule=dm) def test_early_stopping_no_extraneous_invocations(tmpdir): diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 1034fa26a3ac5..5e480d8314ad7 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -337,7 +337,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): # run test set new_trainer = 
Trainer(**trainer_options) - new_trainer.test(pretrained_model) + new_trainer.test(pretrained_model, datamodule=dm) pretrained_model.cpu() dataloaders = model.test_dataloader() diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d709a86273403..a2ff937f27d02 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1304,8 +1304,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): trainer.reset_test_dataloader.assert_called_once() assert tracker.mock_calls == [ - call.reset_val_dataloader(), call.reset_train_dataloader(model=model), + call.reset_val_dataloader(model=model), call.reset_test_dataloader(), ] From 8da6bfc2fde7dd782d1b3e860dd03e75c975ea01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Oct 2021 23:51:47 +0200 Subject: [PATCH 17/50] fix unused imports --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 - tests/plugins/test_tpu_spawn.py | 1 - tests/utilities/test_model_helpers.py | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 84ed7a232ec53..dd85152590915 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -19,7 +19,6 @@ import torch import torch.multiprocessing as mp -from torch.nn import Module from torch.utils.data import DataLoader import pytorch_lightning as pl diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index deaef3d7c5f60..5537125ce3afb 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -21,7 +21,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.training_type import TPUSpawnPlugin -from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader diff --git a/tests/utilities/test_model_helpers.py b/tests/utilities/test_model_helpers.py index a1d5277697bef..74239becf506f 100644 --- a/tests/utilities/test_model_helpers.py +++ b/tests/utilities/test_model_helpers.py @@ -16,7 +16,7 @@ import pytest -from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning import LightningDataModule from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel From 791ee39cb52b5c94d7de8e7b996c57cf5bee83c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 00:03:26 +0200 Subject: [PATCH 18/50] update unit tests to use attach data function --- tests/trainer/test_trainer_tricks.py | 30 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 922dbdd13ab41..a1bc7e6cafd49 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir): # test train loader applies correct limits # ------------------------------------------------------ trainer = Trainer(overfit_batches=4) + trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) assert trainer.num_training_batches == 4 @@ -93,6 +94,7 @@ def 
test_overfit_batch_limits(tmpdir): assert torch.eq(ya, yb).all() trainer = Trainer(overfit_batches=0.11) + trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) # The dataloader should have been overwritten with a Sequential sampler. assert trainer.train_dataloader is not train_loader @@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as percent # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=0.11) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == num_train_samples # make sure we turned off shuffle for the user @@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as int # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 1 - loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=5) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 5 # ------------------------------------------------------ # test limit_xxx_batches as percent AND int # ------------------------------------------------------ if split == RunningStage.VALIDATING: - loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_val_batches=0.1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(val_loader)) - loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_val_batches=10) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 else: - loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_test_batches=0.1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(test_loader)) - loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_test_batches=10) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 From 54c10860d3aa8026e548bf1e883c7a422e530856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 00:36:21 +0200 Subject: [PATCH 19/50] use dataloader from trainer --- tests/models/test_restore.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 5e480d8314ad7..ec9a626cfb24f 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -539,7 +539,7 @@ def on_pretrain_routine_end(self): # haven't trained with the new loaded model new_trainer.state.stage = RunningStage.VALIDATING - dataloader = self.train_dataloader() + dataloader = self.trainer.train_dataloader tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) self.on_pretrain_routine_end_called = True From 28b95eea4986dcd7a6684de8916eef56f9bc3938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 00:38:17 +0200 Subject: [PATCH 20/50] fix test not using the right dataloader --- tests/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index ec9a626cfb24f..101b3fd678205 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): new_trainer.test(pretrained_model, datamodule=dm) pretrained_model.cpu() - dataloaders = model.test_dataloader() + dataloaders = dm.test_dataloader() if not isinstance(dataloaders, list): dataloaders = [dataloaders] From 5d7689074d5dc09c6536d846760c1f63864d3cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 10:46:24 +0200 Subject: [PATCH 21/50] fix test --- tests/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 101b3fd678205..b0bd9e0ce7b51 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -539,7 +539,7 @@ def on_pretrain_routine_end(self): # haven't trained with the new loaded model new_trainer.state.stage = RunningStage.VALIDATING - dataloader = self.trainer.train_dataloader + dataloader = dm.train_dataloader tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) self.on_pretrain_routine_end_called = True From 8d8f81742eda8b57dc1abde031ec96f917af26b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 10:46:58 +0200 Subject: [PATCH 22/50] fix test --- tests/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index b0bd9e0ce7b51..5c731a17b8a5c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -539,7 +539,7 @@ def on_pretrain_routine_end(self): # haven't trained with the new loaded model new_trainer.state.stage = RunningStage.VALIDATING - dataloader = dm.train_dataloader + dataloader = dm.train_dataloader() tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) self.on_pretrain_routine_end_called = True From 7881a5d817d27fe795eac7e9c2f306ca2f52729a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 10:53:41 +0200 Subject: [PATCH 23/50] remove redundant fixme comment --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a2ff937f27d02..dc07f48ec6389 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1484,7 +1484,6 @@ def on_train_batch_start(self, batch, 
batch_idx: int) -> None: assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) self.on_train_batch_start_called = True - # FIXME: this patching happens after we set the source, so we still call the old method def on_val_dataloader(self) -> None: loader = self.val_dataloader() self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) From 791608304a2926869f8c2c17d788c9b1f3ff5da9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 10:54:11 +0200 Subject: [PATCH 24/50] remove comment --- pytorch_lightning/trainer/connectors/data_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 0b81cbd37a74e..362a10604441a 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -38,7 +38,6 @@ def __init__( self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", - # TODO: remove these args train_data_fetcher: Optional[AbstractDataFetcher] = None, validate_data_fetcher: Optional[AbstractDataFetcher] = None, test_data_fetcher: Optional[AbstractDataFetcher] = None, From 7ec1b5c717e5b6eb0e88c3930df4d94378df6114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 11:23:31 +0200 Subject: [PATCH 25/50] update --- tests/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 5c731a17b8a5c..f20bff249cd09 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -337,7 +337,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): # run test set new_trainer = Trainer(**trainer_options) - new_trainer.test(pretrained_model, datamodule=dm) + new_trainer.test(pretrained_model) pretrained_model.cpu() dataloaders = dm.test_dataloader() From abc8bf6a831eafaf3d7a78f2be74706e360a029f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 12:48:16 +0200 Subject: [PATCH 26/50] rename dataloader source --- .../trainer/connectors/data_connector.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 362a10604441a..daae6da773582 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -50,10 +50,10 @@ def __init__( self.test_data_fetcher = test_data_fetcher self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None - self._train_dataloader_source = DataLoaderSource() - self._val_dataloader_source = DataLoaderSource() - self._test_dataloader_source = DataLoaderSource() - self._predict_dataloader_source = DataLoaderSource() + self._train_dataloader_source = _DataLoaderSource() + self._val_dataloader_source = _DataLoaderSource() + self._test_dataloader_source = _DataLoaderSource() + self._predict_dataloader_source = _DataLoaderSource() @property def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: @@ -203,16 +203,16 @@ def attach_dataloaders( self.trainer.test_dataloaders = None self.trainer.predict_dataloaders = None - self._train_dataloader_source = DataLoaderSource( + self._train_dataloader_source = _DataLoaderSource( train_dataloaders if train_dataloaders is not None else model, "train_dataloader" ) - 
self._val_dataloader_source = DataLoaderSource( + self._val_dataloader_source = _DataLoaderSource( val_dataloaders if val_dataloaders is not None else model, "val_dataloader" ) - self._test_dataloader_source = DataLoaderSource( + self._test_dataloader_source = _DataLoaderSource( test_dataloaders if test_dataloaders is not None else model, "test_dataloader" ) - self._predict_dataloader_source = DataLoaderSource( + self._predict_dataloader_source = _DataLoaderSource( predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader" ) @@ -223,10 +223,10 @@ def attach_datamodule( if datamodule is None: return - self._train_dataloader_source = DataLoaderSource(datamodule, "train_dataloader") - self._val_dataloader_source = DataLoaderSource(datamodule, "val_dataloader") - self._test_dataloader_source = DataLoaderSource(datamodule, "test_dataloader") - self._predict_dataloader_source = DataLoaderSource(datamodule, "predict_dataloader") + self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader") + self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader") + self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader") + self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader") # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") @@ -261,7 +261,7 @@ def teardown(self) -> None: # TODO: type for list/dict of dataloaders @dataclass -class DataLoaderSource: +class _DataLoaderSource: instance: Optional[Union[_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] = None name: str = "" From 5587ea432f74da86b1a3b9105cc0d74da4ba88ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 12:53:40 +0200 Subject: [PATCH 27/50] typing dataloaders --- .../trainer/connectors/data_connector.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index daae6da773582..10e56aa5ec0de 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -256,26 +256,24 @@ def teardown(self) -> None: self.sanity_check_data_fetcher = None -_DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader]] - - -# TODO: type for list/dict of dataloaders @dataclass class _DataLoaderSource: - instance: Optional[Union[_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] = None + instance: Optional[ + Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"] + ] = None name: str = "" @property def available(self) -> bool: return not self.is_module() or is_overridden(self.name, self.instance) - def request(self) -> _DATALOADERS: + def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: if self.is_module() and self.name: return getattr(self.instance, self.name)() return self.instance def is_module(self) -> bool: - from pytorch_lightning import LightningDataModule, LightningModule + from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import return isinstance(self.instance, (LightningModule, LightningDataModule)) From 4baa4dad7deaac54ac028963d768d470860cfa7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 
13:10:25 +0200 Subject: [PATCH 28/50] add docs --- .../trainer/connectors/data_connector.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 10e56aa5ec0de..73929e1f2cc0e 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -258,6 +258,19 @@ def teardown(self) -> None: @dataclass class _DataLoaderSource: + """Stores the information where the dataloaders come from. + + The source can be + + 1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`, + 2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`, + 3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof. + + Arguments: + instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s). + name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook + that returns the desired dataloader(s). + """ instance: Optional[ Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"] ] = None name: str = "" @property def available(self) -> bool: + """Returns whether the source dataloader is available. If the source is a module it checks that the method + with given :attr:`name` is overridden.""" return not self.is_module() or is_overridden(self.name, self.instance) def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + """Returns the dataloader from the source. If the source is a module, the method with the corresponding + :attr:`name` gets called.""" if self.is_module() and self.name: return getattr(self.instance, self.name)() return self.instance def is_module(self) -> bool: + """Returns whether the DataLoader source is a LightningModule or a LightningDataModule. + It does not check whether ``*_dataloader`` methods are actually overridden.""" from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import return isinstance(self.instance, (LightningModule, LightningDataModule)) From 6cca816d7270ea347f8528821d8488f195b93da0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 13:14:26 +0200 Subject: [PATCH 29/50] update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6461b95a05d66..5165b805ce095 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -445,6 +445,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed a redundant warning with `ModelCheckpoint(monitor=None)` callback ([#9875](https://github.com/PyTorchLightning/pytorch-lightning/pull/9875)) +- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764)) + + ### Fixed @@ -501,6 +504,9 @@
- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857)) +- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764)) + + ## [1.4.9] - 2021-09-30 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704)) From caa869bffdf420e4a7a4acdcbe5e640365ae4c0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Oct 2021 11:15:39 +0000 Subject: [PATCH 30/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/data_connector.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 73929e1f2cc0e..2895152f545cb 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -279,20 +279,27 @@ class _DataLoaderSource: @property def available(self) -> bool: - """Returns whether the source dataloader is available. If the source is a module it checks that the method - with given :attr:`name` is overridden.""" + """Returns whether the source dataloader is available. + + If the source is a module it checks that the method with given :attr:`name` is overridden. + """ return not self.is_module() or is_overridden(self.name, self.instance) def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: - """Returns the dataloader from the source. If the source is a module, the method with the corresponding - :attr:`name` gets called.""" + """Returns the dataloader from the source. + + If the source is a module, the method with the corresponding + :attr:`name` gets called. + """ if self.is_module() and self.name: return getattr(self.instance, self.name)() return self.instance def is_module(self) -> bool: """Returns whether the DataLoader source is a LightningModule or a LightningDataModule. + + It does not check whether ``*_dataloader`` methods are actually overridden. + """ from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import return isinstance(self.instance, (LightningModule, LightningDataModule)) From ccbe47785117090ee5297991f83821be28cc0cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 13:38:23 +0200 Subject: [PATCH 31/50] add unit tests --- .../trainer/connectors/test_data_connector.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/trainer/connectors/test_data_connector.py diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py new file mode 100644 index 0000000000000..2b5acb157dab4 --- /dev/null +++ b/tests/trainer/connectors/test_data_connector.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource +from tests.helpers import BoringModel, BoringDataModule + + +class NoDataLoaderModel(BoringModel): + def __init__(self): + super().__init__() + self.train_dataloader = None + + +@pytest.mark.parametrize( + "instance,available", + [ + (None, True), + (BoringModel().train_dataloader(), True), + (BoringModel(), True), + (NoDataLoaderModel(), False), + (BoringDataModule(), True), + ], +) +def test_dataloader_source_available(instance, available): + """Test the availability check for _DataLoaderSource.""" + source = _DataLoaderSource(instance=instance, name="train_dataloader") + assert source.available is available + + +def test_dataloader_source_direct_access(): + """Test requesting a dataloader when the source is already a dataloader.""" + dataloader = BoringModel().train_dataloader() + source = _DataLoaderSource(instance=dataloader, name="any") + assert not source.is_module() + assert source.available + assert source.request() is dataloader + + +def test_dataloader_source_request_from_module(): + """Test requesting a dataloader from a module works.""" + module = BoringModel() + module.foo = Mock(return_value=module.train_dataloader()) + + source = _DataLoaderSource(module, "foo") + assert source.is_module() + module.foo.assert_not_called() + assert isinstance(source.request(), DataLoader) + module.foo.assert_called_once() From 909489c6d297051206bf3022e7d171fdcf03caae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Oct 2021 11:40:23 +0000 Subject: [PATCH 32/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/connectors/test_data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index 2b5acb157dab4..3cb73b7a879b6 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource -from tests.helpers import BoringModel, BoringDataModule +from tests.helpers import BoringDataModule, BoringModel class NoDataLoaderModel(BoringModel): From 8dbb918ef99eea4208e018fe56591ca7d53d1370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 13:40:25 +0200 Subject: [PATCH 33/50] delete methods --- tests/trainer/connectors/test_data_connector.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index 2b5acb157dab4..fb29f829b484c 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -24,6 +24,9 @@ class NoDataLoaderModel(BoringModel): def __init__(self): super().__init__() 
self.train_dataloader = None + self.val_dataloader = None + self.test_dataloader = None + self.predict_dataloader = None @pytest.mark.parametrize( From c4967792c1ca4946ee41f96f24bb653d5ed09e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 13:47:09 +0200 Subject: [PATCH 34/50] address fixme --- pytorch_lightning/trainer/data_loading.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 13e53d44feae8..a48620d38120f 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -536,20 +536,18 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No if self.val_dataloaders is None: self.reset_val_dataloader(model=model) - # FIXME: wrong docstring def request_dataloader( self, stage: RunningStage, model: Optional["pl.LightningModule"] = None ) -> Union[DataLoader, List[DataLoader]]: - """Handles downloading data in the GPU or TPU case. + """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. Returns: - The dataloader + The requested dataloader """ source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source") hook = f"{stage.dataloader_prefix}_dataloader" self.call_hook("on_" + hook, pl_module=model) - # dataloader = self.call_hook(hook, pl_module=model) dataloader = source.request() if isinstance(dataloader, tuple): dataloader = list(dataloader) From ff906c11c85887d61a3eda34a15bda6b05d53d4e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 12 Oct 2021 19:05:49 +0530 Subject: [PATCH 35/50] val sanity --- pytorch_lightning/trainer/trainer.py | 4 +++- tests/trainer/test_dataloaders.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cfd8e7b5ef063..60842bdc084a5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1240,7 +1240,9 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: return self.predict_loop.run() def _run_sanity_check(self, ref_model): - using_val_step = ref_model.val_dataloader is not None and is_overridden("validation_step", ref_model) + using_val_step = self.data_connector._val_dataloader_source.available is not None and is_overridden( + "validation_step", ref_model + ) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 # run tiny validation (if validation defined) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 176388b1351ae..9a3c79bdd3cf6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1304,8 +1304,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): trainer.reset_test_dataloader.assert_called_once() assert tracker.mock_calls == [ + call.reset_val_dataloader(), call.reset_train_dataloader(model=model), - call.reset_val_dataloader(model=model), call.reset_test_dataloader(), ] From fb0f347745bab1d719c681c6cd7c1e7d114aad1e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 12 Oct 2021 19:08:07 +0530 Subject: [PATCH 36/50] val sanity --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 60842bdc084a5..349edc38321ce 100644 --- 
a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1240,7 +1240,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: return self.predict_loop.run() def _run_sanity_check(self, ref_model): - using_val_step = self.data_connector._val_dataloader_source.available is not None and is_overridden( + using_val_step = self.data_connector._val_dataloader_source.available and is_overridden( "validation_step", ref_model ) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 From 2f2a4312d86b734c1a854d5110c05444c3ee43d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 10:06:17 +0200 Subject: [PATCH 37/50] is_available --- .../trainer/configuration_validator.py | 4 ++-- .../trainer/connectors/data_connector.py | 15 +++++++-------- pytorch_lightning/trainer/data_loading.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 2 +- tests/trainer/connectors/test_data_connector.py | 4 ++-- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 7d6de6db08022..a945d36c969a2 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -67,7 +67,7 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None # ----------------------------------- # verify model has a train dataloader # ----------------------------------- - has_train_dataloader = self.trainer.data_connector._train_dataloader_source.available + has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_available() if not has_train_dataloader: raise MisconfigurationException( "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" @@ -176,7 +176,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s ) def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None: - has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.available + has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_available() if not has_predict_dataloader: raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 2895152f545cb..d71d8fba6fb1e 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -277,14 +277,6 @@ class _DataLoaderSource: ] = None name: str = "" - @property - def available(self) -> bool: - """Returns whether the source dataloader is available. - - If the source is a module it checks that the method with given :attr:`name` is overridden. - """ - return not self.is_module() or is_overridden(self.name, self.instance) - def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. @@ -295,6 +287,13 @@ def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return getattr(self.instance, self.name)() return self.instance + def is_available(self) -> bool: + """Returns whether the source dataloader is available. + + If the source is a module it checks that the method with given :attr:`name` is overridden. 
+ """ + return not self.is_module() or is_overridden(self.name, self.instance) + def is_module(self) -> bool: """Returns whether the the DataLoader source is a LightningModule or a LightningDataModule. diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a48620d38120f..90857981d993b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -490,7 +490,7 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._val_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("validation_step", pl_module) - if source.available and has_step: + if source.is_available() and has_step: self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( RunningStage.VALIDATING, model=pl_module ) @@ -504,7 +504,7 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._test_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("test_step", pl_module) - if source.available and has_step: + if source.is_available() and has_step: self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( RunningStage.TESTING, model=pl_module ) @@ -517,7 +517,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) """ source = self.data_connector._predict_dataloader_source pl_module = self.lightning_module or model - if source.available: + if source.is_available(): self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( RunningStage.PREDICTING, model=pl_module ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 349edc38321ce..bc8ea790ac0bc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1240,7 +1240,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: return self.predict_loop.run() def _run_sanity_check(self, ref_model): - using_val_step = self.data_connector._val_dataloader_source.available and is_overridden( + using_val_step = self.data_connector._val_dataloader_source.is_available() and is_overridden( "validation_step", ref_model ) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index dc20b42a96364..9328bc4b6eb89 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -42,7 +42,7 @@ def __init__(self): def test_dataloader_source_available(instance, available): """Test the availability check for _DataLoaderSource.""" source = _DataLoaderSource(instance=instance, name="train_dataloader") - assert source.available is available + assert source.is_available() is available def test_dataloader_source_direct_access(): @@ -50,7 +50,7 @@ def test_dataloader_source_direct_access(): dataloader = BoringModel().train_dataloader() source = _DataLoaderSource(instance=dataloader, name="any") assert not source.is_module() - assert source.available + assert source.is_available() assert source.request() is dataloader From f3bb2fbb5783c9c28af6835b5ada5de827bc892a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 10:10:56 +0200 Subject: [PATCH 38/50] simplify --- .../plugins/training_type/tpu_spawn.py | 20 +++++++++---------- 1 file 
changed, 9 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index dd85152590915..9fdf74e7cbd95 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -97,17 +97,15 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: """Validate and fail fast if the dataloaders were passed directly to fit.""" connector: DataConnector = model.trainer.data_connector - if not connector._train_dataloader_source.is_module(): - TPUSpawnPlugin._validate_dataloader(connector._train_dataloader_source.instance) - - if not connector._val_dataloader_source.is_module(): - TPUSpawnPlugin._validate_dataloader(connector._val_dataloader_source.instance) - - if not connector._test_dataloader_source.is_module(): - TPUSpawnPlugin._validate_dataloader(connector._test_dataloader_source.instance) - - if not connector._predict_dataloader_source.is_module(): - TPUSpawnPlugin._validate_dataloader(connector._predict_dataloader_source.instance) + sources = ( + connector._train_dataloader_source, + connector._val_dataloader_source, + connector._test_dataloader_source, + connector._predict_dataloader_source, + ) + for dataloader_source in sources: + if not dataloader_source.is_module(): + TPUSpawnPlugin._validate_dataloader(dataloader_source.instance) def connect(self, model: "pl.LightningModule") -> None: TPUSpawnPlugin._validate_patched_dataloaders(model) From a68bbe393984e5a1b519a73b1049cf794a7b2ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 10:38:42 +0200 Subject: [PATCH 39/50] use call_hook() for LightningModule --- .../trainer/connectors/data_connector.py | 11 ++++++++++- tests/trainer/connectors/test_data_connector.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index d71d8fba6fb1e..7b9a498fe6008 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -283,8 +283,17 @@ def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: If the source is a module, the method with the corresponding :attr:`name` gets called. 
""" - if self.is_module() and self.name: + from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import + + if not self.name: + return self.instance + + if isinstance(self.instance, LightningModule): + return self.instance.trainer.call_hook(self.name, pl_module=self.instance) + + if isinstance(self.instance, LightningDataModule): return getattr(self.instance, self.name)() + return self.instance def is_available(self) -> bool: diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index 9328bc4b6eb89..1e2b66d45d0d3 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -16,6 +16,7 @@ import pytest from torch.utils.data import DataLoader +from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource from tests.helpers import BoringDataModule, BoringModel @@ -57,6 +58,7 @@ def test_dataloader_source_direct_access(): def test_dataloader_source_request_from_module(): """Test requesting a dataloader from a module works.""" module = BoringModel() + module.trainer = Trainer() module.foo = Mock(return_value=module.train_dataloader()) source = _DataLoaderSource(module, "foo") From c355a9dd69b5f4b34d0330ac327672550c14dc6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 12:38:27 +0200 Subject: [PATCH 40/50] ensure model has a trainer in unit tests --- tests/trainer/test_trainer_tricks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index a1bc7e6cafd49..1dd3ab92eb833 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir): # test train loader applies correct limits # ------------------------------------------------------ trainer = Trainer(overfit_batches=4) + model.trainer = trainer trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) assert trainer.num_training_batches == 4 @@ -94,6 +95,7 @@ def test_overfit_batch_limits(tmpdir): assert torch.eq(ya, yb).all() trainer = Trainer(overfit_batches=0.11) + model.trainer = trainer trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) # The dataloader should have been overwritten with a Sequential sampler. 
From 3052405eafe7e498c504da8506883fbb7c9a389f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 14 Oct 2021 09:55:50 -0400 Subject: [PATCH 41/50] fix deepspeed dl request --- pytorch_lightning/plugins/training_type/deepspeed.py | 5 +++-- pytorch_lightning/trainer/connectors/data_connector.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index e2e8c316f48d1..405ec87e13a45 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -589,8 +589,9 @@ def _auto_select_batch_size(self): # train_micro_batch_size_per_gpu is used for throughput logging purposes # by default we try to use the batch size of the loader batch_size = 1 - if hasattr(self.lightning_module, "train_dataloader"): - train_dataloader = self.lightning_module.train_dataloader() + train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source + if train_dl_source.is_available(): + train_dataloader = train_dl_source.request() if hasattr(train_dataloader, "batch_sampler"): batch_size = train_dataloader.batch_sampler.batch_size return batch_size diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 7b9a498fe6008..d1cfb16f4adbd 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,9 +14,7 @@ import os from dataclasses import dataclass from functools import partial -from typing import Dict, Iterable, List, Optional, Union - -from torch.utils.data import DataLoader +from typing import Iterable, Optional, Union import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation From 535f42315f45ad0ca3990805bc0d37823517a700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 19:15:04 +0200 Subject: [PATCH 42/50] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/connectors/data_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index d1cfb16f4adbd..1c288207fdff7 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -278,8 +278,7 @@ class _DataLoaderSource: def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. - If the source is a module, the method with the corresponding - :attr:`name` gets called.
""" from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import @@ -290,7 +289,8 @@ def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return self.instance.trainer.call_hook(self.name, pl_module=self.instance) if isinstance(self.instance, LightningDataModule): - return getattr(self.instance, self.name)() + method = getattr(self.instance, self.name) + return method() return self.instance From 5f9b699201fd07df0649e047546ae9e6f29829a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 14 Oct 2021 23:46:08 +0200 Subject: [PATCH 43/50] Update pytorch_lightning/utilities/model_helpers.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/utilities/model_helpers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 240470645c481..3146b33fe153d 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -64,7 +64,4 @@ def is_overridden( if parent_attr is None: raise ValueError("The parent should define the method") - instance_code = instance_attr.__code__ - parent_code = parent_attr.__code__ - - return instance_code != parent_code + return instance_attr.__code__ != parent_attr.__code__ From 2e0496e373e3cf2427abdcf72b301fcb1dd1d5f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 15 Oct 2021 00:30:10 +0200 Subject: [PATCH 44/50] rename is_available --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- pytorch_lightning/trainer/data_loading.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 2 +- tests/trainer/connectors/test_data_connector.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 405ec87e13a45..05f00f7fb0119 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -590,7 +590,7 @@ def _auto_select_batch_size(self): # by default we try to use the batch size of the loader batch_size = 1 train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source - if train_dl_source.is_available(): + if train_dl_source.is_defined(): train_dataloader = train_dl_source.request() if hasattr(train_dataloader, "batch_sampler"): batch_size = train_dataloader.batch_sampler.batch_size diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index a9d4d691079f7..bf38b741b8145 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -67,7 +67,7 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None # ----------------------------------- # verify model has a train dataloader # ----------------------------------- - has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_available() + has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_defined() if not has_train_dataloader: raise MisconfigurationException( "No `train_dataloader()` method defined. 
Lightning `Trainer` expects as minimum a" @@ -176,7 +176,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s ) def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None: - has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_available() + has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_defined() if not has_predict_dataloader: raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 1c288207fdff7..447151025c1a9 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -294,8 +294,8 @@ def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return self.instance - def is_available(self) -> bool: - """Returns whether the source dataloader is available. + def is_defined(self) -> bool: + """Returns whether the source dataloader can be retrieved or not. If the source is a module it checks that the method with given :attr:`name` is overridden. """ diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 90857981d993b..669b4d0653898 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -490,7 +490,7 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._val_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("validation_step", pl_module) - if source.is_available() and has_step: + if source.is_defined() and has_step: self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( RunningStage.VALIDATING, model=pl_module ) @@ -504,7 +504,7 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> source = self.data_connector._test_dataloader_source pl_module = self.lightning_module or model has_step = is_overridden("test_step", pl_module) - if source.is_available() and has_step: + if source.is_defined() and has_step: self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( RunningStage.TESTING, model=pl_module ) @@ -517,7 +517,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) """ source = self.data_connector._predict_dataloader_source pl_module = self.lightning_module or model - if source.is_available(): + if source.is_defined(): self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( RunningStage.PREDICTING, model=pl_module ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2287842389132..b653a0df4f9ae 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1259,7 +1259,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: return self.predict_loop.run() def _run_sanity_check(self, ref_model): - using_val_step = self.data_connector._val_dataloader_source.is_available() and is_overridden( + using_val_step = self.data_connector._val_dataloader_source.is_defined() and is_overridden( "validation_step", ref_model ) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 diff --git a/tests/trainer/connectors/test_data_connector.py 
b/tests/trainer/connectors/test_data_connector.py index 1e2b66d45d0d3..ad315b179acd9 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -43,7 +43,7 @@ def __init__(self): def test_dataloader_source_available(instance, available): """Test the availability check for _DataLoaderSource.""" source = _DataLoaderSource(instance=instance, name="train_dataloader") - assert source.is_available() is available + assert source.is_defined() is available def test_dataloader_source_direct_access(): @@ -51,7 +51,7 @@ def test_dataloader_source_direct_access(): dataloader = BoringModel().train_dataloader() source = _DataLoaderSource(instance=dataloader, name="any") assert not source.is_module() - assert source.is_available() + assert source.is_defined() assert source.request() is dataloader From c56eb608e95ded9a7081c2e5f977d339854c01b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 10:52:31 +0200 Subject: [PATCH 45/50] resolve merge error --- .../trainer/configuration_validator.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index bdacd06517f54..5de31a764ecf3 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -62,15 +62,15 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) - # ----------------------------------- - # verify model has a train dataloader - # ----------------------------------- - has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_defined() - if not has_train_dataloader: - raise MisconfigurationException( - "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" - " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." - ) + # ----------------------------------- + # verify model has a train dataloader + # ----------------------------------- + has_train_dataloader = trainer.data_connector._train_dataloader_source.is_defined() + if not has_train_dataloader: + raise MisconfigurationException( + "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" + " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." 
+ ) # ----------------------------------- # verify model has optimizer From 7fc97ed4cbe5fc55d87a7c40c688fbe3fe5c64fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 10:57:21 +0200 Subject: [PATCH 46/50] address reviews --- .../plugins/training_type/tpu_spawn.py | 6 +++--- .../trainer/connectors/data_connector.py | 14 ++++++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index bd9772cdc7dc1..6d18612b94f50 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -104,9 +104,9 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: connector._test_dataloader_source, connector._predict_dataloader_source, ) - for dataloader_source in sources: - if not dataloader_source.is_module(): - TPUSpawnPlugin._validate_dataloader(dataloader_source.instance) + for source in sources: + if not source.is_module(): + TPUSpawnPlugin._validate_dataloader(source.instance) def connect(self, model: "pl.LightningModule") -> None: TPUSpawnPlugin._validate_patched_dataloaders(model) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 447151025c1a9..6037acdcda6f5 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -48,10 +48,10 @@ def __init__( self.test_data_fetcher = test_data_fetcher self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None - self._train_dataloader_source = _DataLoaderSource() - self._val_dataloader_source = _DataLoaderSource() - self._test_dataloader_source = _DataLoaderSource() - self._predict_dataloader_source = _DataLoaderSource() + self._train_dataloader_source = _DataLoaderSource(None, "") + self._val_dataloader_source = _DataLoaderSource(None, "") + self._test_dataloader_source = _DataLoaderSource(None, "") + self._predict_dataloader_source = _DataLoaderSource(None, "") @property def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: @@ -270,10 +270,8 @@ class _DataLoaderSource: that returns the desired dataloader(s). """ - instance: Optional[ - Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"] - ] = None - name: str = "" + instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] + name: str def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. 
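Note: with the field defaults removed in the patch above, the connector now constructs all four sources explicitly as `_DataLoaderSource(None, "")`. The precedence rule that `attach_dataloaders` implements (see [PATCH 26/50]) reduces to a one-liner; `pick_source` below is an illustrative sketch, not a function in the codebase:

    from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource


    def pick_source(model, dataloaders, name):
        # A loader passed directly to fit()/validate()/test()/predict() takes
        # precedence; otherwise the hook named `name` on the model is used.
        # attach_datamodule() later re-points all four sources at the datamodule.
        return _DataLoaderSource(dataloaders if dataloaders is not None else model, name)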
From c194b4456a9417ba08dea21dda531d016315a9bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 11:00:31 +0200 Subject: [PATCH 47/50] rename request -> dataloader --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- pytorch_lightning/trainer/data_loading.py | 2 +- tests/trainer/connectors/test_data_connector.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 6344a1440a55a..1e71c4041b04e 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -630,7 +630,7 @@ def _auto_select_batch_size(self): batch_size = 1 train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source if train_dl_source.is_defined(): - train_dataloader = train_dl_source.request() + train_dataloader = train_dl_source.dataloader() if hasattr(train_dataloader, "batch_sampler"): batch_size = train_dataloader.batch_sampler.batch_size return batch_size diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 6037acdcda6f5..9b6f97f1ebec4 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -273,7 +273,7 @@ class _DataLoaderSource: instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] name: str - def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called.
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 43549ba83836b..59d05c38f0e1d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -549,7 +549,7 @@ def request_dataloader( hook = f"{stage.dataloader_prefix}_dataloader" self.call_hook("on_" + hook, pl_module=model) - dataloader = source.request() + dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index ad315b179acd9..4d614ecc25f6e 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -52,7 +52,7 @@ def test_dataloader_source_direct_access(): source = _DataLoaderSource(instance=dataloader, name="any") assert not source.is_module() assert source.is_defined() - assert source.request() is dataloader + assert source.dataloader() is dataloader def test_dataloader_source_request_from_module(): @@ -64,5 +64,5 @@ def test_dataloader_source_request_from_module(): source = _DataLoaderSource(module, "foo") assert source.is_module() module.foo.assert_not_called() - assert isinstance(source.request(), DataLoader) + assert isinstance(source.dataloader(), DataLoader) module.foo.assert_called_once() From 7883b22f4c00862a9de65cc76e508d51e3293514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 11:26:09 +0200 Subject: [PATCH 48/50] update predict check --- pytorch_lightning/trainer/configuration_validator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 5de31a764ecf3..88c319ac57431 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -38,7 +38,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule elif trainer.state.fn == TrainerFn.TESTING: __verify_eval_loop_configuration(model, "test") elif trainer.state.fn == TrainerFn.PREDICTING: - __verify_predict_loop_configuration(model) + __verify_predict_loop_configuration(trainer, model) __verify_dp_batch_transfer_support(trainer, model) _check_add_get_queue(model) # TODO(@daniellepintz): Delete _check_progress_bar in v1.7 @@ -175,8 +175,8 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> ) -def __verify_predict_loop_configuration(model: "pl.LightningModule") -> None: - has_predict_dataloader = is_overridden("predict_dataloader", model) +def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined() if not has_predict_dataloader: raise MisconfigurationException("Dataloader not found for `Trainer.predict`") # ---------------------------------------------- From b8e80349e6c8ac089999e26446b3063bba756ec1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 14:34:53 +0200 Subject: [PATCH 49/50] fix bug in example --- pl_examples/basic_examples/backbone_image_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 
f3e0297d0ed15..c388745bb3ea5 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -125,7 +125,7 @@ def cli_main(): cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) cli.trainer.fit(cli.model, datamodule=cli.datamodule) cli.trainer.test(ckpt_path="best") - predictions = cli.trainer.predict(ckpt_path="best") + predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule) print(predictions[0]) From 4fb6081fde3557e2193c560cfa1be4a27c9239d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 14:45:26 +0200 Subject: [PATCH 50/50] add datamodules test() and predict() calls, otherwise loops get skipped with warning --- pl_examples/basic_examples/autoencoder.py | 4 ++-- pl_examples/basic_examples/backbone_image_classifier.py | 2 +- pl_examples/basic_examples/dali_image_classifier.py | 2 +- pl_examples/basic_examples/simple_image_classifier.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index e4890c0867b05..dc1e3d09d0a59 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -181,8 +181,8 @@ def cli_main(): trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10}, ) cli.trainer.fit(cli.model, datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best") - predictions = cli.trainer.predict(ckpt_path="best") + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) + predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule) print(predictions[0]) diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index c388745bb3ea5..1f279ca85b4bb 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -124,7 +124,7 @@ def predict_dataloader(self): def cli_main(): cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) cli.trainer.fit(cli.model, datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best") + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule) print(predictions[0]) diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 2cbc35f6b4805..49bebf44ca522 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -194,7 +194,7 @@ def cli_main(): cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False) cli.trainer.fit(cli.model, datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best") + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) if __name__ == "__main__": diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 8e2850e17cd8a..146f25c27c0d4 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -74,7 +74,7 @@ def cli_main(): LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False ) cli.trainer.fit(cli.model, 
datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best") + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) if __name__ == "__main__":
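Note: taken together, the series replaces the `_PatchDataLoader` mechanism with `_DataLoaderSource` lookups: the configuration checks use `is_defined()` ([PATCH 44/50]) and retrieval uses `dataloader()` ([PATCH 47/50]), so passing loaders to `Trainer.fit` no longer mutates the model. A hedged end-to-end sketch of the user-visible effect, assuming the `BoringModel` test helper referenced in the diffs above:

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel

    model = BoringModel()
    loader = model.train_dataloader()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, train_dataloaders=loader)

    # Previously _PatchDataLoader would have overwritten model.train_dataloader;
    # now the hook is untouched and the connector tracks the loader instead.
    assert "train_dataloader" not in vars(model)
    source = trainer.data_connector._train_dataloader_source
    assert not source.is_module()
    assert source.dataloader() is loader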