diff --git a/CHANGELOG.md b/CHANGELOG.md index 5457a6e980318..66bf99ba78a43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564)) +- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). + ### Changed - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7bfd97bfb83f1..23c40bfb50a78 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -672,7 +672,7 @@ def fit( self, model: LightningModule, train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[DataLoader] = None + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None ): r""" Runs the full optimization routine. @@ -913,7 +913,11 @@ def run_pretrain_routine(self, model: LightningModule): # CORE TRAINING LOOP self.train() - def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None): + def test( + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None + ): r""" Separates from fit to make sure you never run on your test set until you want to. diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index b6f6262ee90e1..d847b6c8730c2 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -113,7 +113,7 @@ class CurrentTestModel( trainer.fit(model) trainer.test() - # verify there are 2 val loaders + # verify there are 2 test loaders assert len(trainer.test_dataloaders) == 2, \ 'Multiple test_dataloaders not initiated properly' @@ -125,7 +125,7 @@ class CurrentTestModel( trainer.test() -def test_train_dataloaders_passed_to_fit(tmpdir): +def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ class CurrentTestModel(LightTrainDataloader, TestModelBase): @@ -175,7 +175,7 @@ class CurrentTestModel( def test_all_dataloaders_passed_to_fit(tmpdir): - """Verify train, val & test dataloader can be passed to fit """ + """Verify train, val & test dataloader(s) can be passed to fit and test method""" class CurrentTestModel( LightTrainDataloader,