diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst
index fb1aa33f80462..2e3e3201b2181 100644
--- a/docs/source/advanced/multiple_loaders.rst
+++ b/docs/source/advanced/multiple_loaders.rst
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.

 ----------

+.. _multiple-training-dataloaders:
+
 Multiple training dataloaders
 -----------------------------
 For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
@@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer

             return loaders

+Furthermore, Lightning also supports returning nested lists and dicts (or a
+combination thereof).
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
+            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
+            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)
+
+            # pass loaders as a nested dict. This will create batches like this:
+            # {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
+            #  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
+            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
+                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
+            return loaders
+
 ----------

 Test/Val dataloaders
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index e0b33c1219e8b..604803365298c 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -383,12 +383,14 @@ def prepare_data(self):
             model.test_dataloader()
         """

-    def train_dataloader(self) -> DataLoader:
+    def train_dataloader(self) -> Any:
         """
-        Implement a PyTorch DataLoader for training.
+        Implement one or more PyTorch DataLoaders for training.

         Return:
-            Single PyTorch :class:`~torch.utils.data.DataLoader`.
+            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
+            (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
+            this :ref:`page <multiple-training-dataloaders>`.

         The dataloader you return will not be called every epoch unless you set
         :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:

         Example::

+            # single dataloader
             def train_dataloader(self):
                 transform = transforms.Compose([transforms.ToTensor(),
                                                 transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
                 )
                 return loader

+            # multiple dataloaders, return as list
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
+                return [mnist_loader, cifar_loader]
+
+            # multiple dataloaders, return as dict
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
+                return {'mnist': mnist_loader, 'cifar': cifar_loader}
+
         """
         rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 58ce7f09ea2a2..dd81f4b53ce13 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@
 import warnings
 from itertools import count
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
 from torch.utils.data import DataLoader
@@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule):
     def fit(
         self,
         model: LightningModule,
-        train_dataloader: Optional[DataLoader] = None,
+        train_dataloader: Any = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         datamodule: Optional[LightningDataModule] = None,
     ):
@@ -437,8 +437,9 @@

             model: Model to fit.

-            train_dataloader: A Pytorch DataLoader with training samples. If the model has
-                a predefined train_dataloader method this will be skipped.
+            train_dataloader: Either a single PyTorch DataLoader or a collection of these
+                (list, dict, nested lists and dicts). In the case of multiple dataloaders,
+                please see this :ref:`page <multiple-training-dataloaders>`.

             val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                 If the model has a predefined val_dataloaders method this will be skipped
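For reviewers who want to try the dict form above, here is a minimal sketch of a ``training_step`` consuming the combined batch. The batch layout follows the comments in this diff; ``self.encoder`` and ``self.head`` are hypothetical placeholder submodules, not part of this PR::

    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            # `train_dataloader` returned {'mnist': mnist_loader, 'cifar': cifar_loader},
            # so `batch` arrives here as a dict with one sub-batch per key.
            x_mnist, y_mnist = batch['mnist']
            x_cifar, y_cifar = batch['cifar']

            # `self.encoder` / `self.head` are hypothetical submodules, shown only
            # to make the example concrete
            loss_mnist = F.cross_entropy(self.head(self.encoder(x_mnist)), y_mnist)
            loss_cifar = F.cross_entropy(self.head(self.encoder(x_cifar)), y_cifar)
            return loss_mnist + loss_cifar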
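The widened ``fit`` signature can also take the collection directly, without defining ``train_dataloader`` on the model. A sketch under the same assumptions (``model`` is any ``LightningModule``; the range datasets are arbitrary)::

    import torch
    from pytorch_lightning import Trainer

    loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
    loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)

    # `train_dataloader` is now typed `Any`, so a dict of loaders is accepted;
    # each training batch then arrives as {'a': batch_a, 'b': batch_b}.
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, train_dataloader={'a': loader_a, 'b': loader_b})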