From ad674c50f1dd0e6f2b4ad1f9cfedf15d93cf7a5b Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 10:55:15 +0530 Subject: [PATCH 1/4] Add trainer.predict config validation --- .../trainer/configuration_validator.py | 9 ++++- tests/trainer/test_trainer.py | 33 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 8c539b5ff478d..9eea5d4479607 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -40,7 +40,8 @@ def verify_loop_configurations(self, model: LightningModule) -> None: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') - # TODO: add predict + elif self.trainer.state == TrainerState.PREDICTING: + self.__verify_predict_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -99,3 +100,9 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') + + def __verify_predict_loop_configuration(self, model): + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3375b02c5496b..2cfe8c5cea060 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1856,3 +1856,36 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) From fadebc6c964d39f508a65d0ed93c8369f2a6698d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 10:58:33 +0530 Subject: [PATCH 2/4] Update Changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5dcb20375137..f33372e81f2ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) +- Added `Trainer.predict` config validation ([#6543]https://github.com/PyTorchLightning/pytorch-lightning/pull/6543) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From fee8f88778e529feed8a9894ee90c990ea0aeaa4 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sat, 20 Mar 2021 01:35:58 +0530 Subject: [PATCH 3/4] Update Tests --- .../trainer/configuration_validator.py | 2 +- tests/trainer/test_config_validator.py | 50 ++++++++++++++++++- tests/trainer/test_trainer.py | 33 ------------ 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9eea5d4479607..a7ba2b1c40123 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -101,7 +101,7 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') - def __verify_predict_loop_configuration(self, model): + def __verify_predict_loop_configuration(self, model: LightningModule) -> None: has_predict_dataloader = is_overridden('predict_dataloader', model) if not has_predict_dataloader: diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 59e10480a485e..9fccd9b36440a 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch -from pytorch_lightning import Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_wrong_train_setting(tmpdir): @@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir): model = BoringModel() model.validation_step = None trainer.validate(model) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2cfe8c5cea060..3375b02c5496b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1856,36 +1856,3 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) - - -@pytest.mark.parametrize("datamodule", [False, True]) -def test_trainer_predict_verify_config(tmpdir, datamodule): - - class TestModel(LightningModule): - - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - - model = TestModel() - - trainer = Trainer(default_root_dir=tmpdir) - - if datamodule: - datamodule = TestLightningDataModule(dataloaders) - results = trainer.predict(model, datamodule=datamodule) - else: - results = trainer.predict(model, dataloaders=dataloaders) - - assert len(results) == 2 - assert results[0][0].shape == torch.Size([1, 2]) - - model.predict_dataloader = None - - with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): - trainer.predict(model) From 631b6e74f3524e78451262f3df7987755d293ec1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 21 Mar 2021 21:34:39 +0100 Subject: [PATCH 4/4] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 304425d251a22..63437bf1d6dfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) -- Added `Trainer.predict` config validation ([#6543]https://github.com/PyTorchLightning/pytorch-lightning/pull/6543) +- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))