diff --git a/CHANGELOG.md b/CHANGELOG.md
index 99eb26935f197..3ba25873bfee7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))
 
+- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 9cb22f39b7228..a5c3a8d04a1dd 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule):
             model: The model to check the configuration.
 
         """
-        if not self.trainer.testing:
+        if self.trainer.predicting:
+            self.__verify_predict_loop_configuration(model)
+        elif not self.trainer.testing:
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'validation')
         else:
@@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name):
             rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
         if has_step and not has_loader:
             rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')
+
+    def __verify_predict_loop_configuration(self, model):
+
+        has_predict_dataloader = is_overridden('predict_dataloader', model)
+        if not has_predict_dataloader:
+            raise MisconfigurationException('Dataloader not found for `Trainer.predict`')
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 59f3c2b54c13c..6966edc3cbf70 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1850,3 +1850,35 @@ def compare_optimizers():
     trainer.max_epochs = 2  # simulate multiple fit calls
     trainer.fit(model)
     compare_optimizers()
+
+
+@pytest.mark.parametrize("use_datamodule", [False, True])
+def test_trainer_predict_verify_config(tmpdir, use_datamodule):
+
+    class TestModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+    dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    if use_datamodule:
+        datamodule = TestLightningDataModule(dataloaders)
+        results = trainer.predict(model, datamodule=datamodule)
+    else:
+        results = trainer.predict(model, dataloaders=dataloaders)
+
+    assert len(results) == 2
+    assert results[0][0].shape == torch.Size([1, 2])
+
+    model.predict_dataloader = None
+
+    with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
+        trainer.predict(model)