Update docs on arg train_dataloader in fit (#6076)
* add to docs

* update docs

* Apply suggestions from code review

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* nested loaders

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* shorten text length

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
4 people authored and tchaton committed Mar 9, 2021
1 parent cf9e408 commit de0efa9
Showing 3 changed files with 60 additions and 7 deletions.
23 changes: 23 additions & 0 deletions docs/source/advanced/multiple_loaders.rst
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.

----------

.. _multiple-training-dataloaders:

Multiple training dataloaders
-----------------------------
For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
@@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer

return loaders

Furthermore, Lightning also supports nested lists and dicts (or a combination) as the return value:

.. testcode::

    class LitModel(LightningModule):

        def train_dataloader(self):

            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)

            # pass loaders as a nested dict. This will create batches like this:
            # {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
            #  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
            return loaders
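For orientation (an editor's sketch, not part of this diff): assuming Lightning's default behaviour of combining the loaders into one batch per step, the batch handed to ``training_step`` mirrors the nested dict returned above, so each inner batch can be looked up by its keys. A minimal sketch, meant to sit inside the ``LitModel`` from the example:

    def training_step(self, batch, batch_idx):
        # `batch` mirrors the dict returned by `train_dataloader`:
        # {'loaders_a_b': {'a': ..., 'b': ...}, 'loaders_c_d': {'c': ..., 'd': ...}}
        batch_a = batch['loaders_a_b']['a']
        batch_d = batch['loaders_c_d']['d']
        # placeholder "loss" purely for illustration
        return batch_a.float().mean() + batch_d.float().mean()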

----------

Test/Val dataloaders
35 changes: 32 additions & 3 deletions pytorch_lightning/core/hooks.py
@@ -383,12 +383,14 @@ def prepare_data(self):
model.test_dataloader()
"""

def train_dataloader(self) -> DataLoader:
def train_dataloader(self) -> Any:
"""
Implement a PyTorch DataLoader for training.
Implement one or more PyTorch DataLoaders for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
this :ref:`page <multiple-training-dataloaders>`.
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:
Example::

    # single dataloader
    def train_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
        )
        return loader
    # multiple dataloaders, return as list
    def train_dataloader(self):
        mnist = MNIST(...)
        cifar = CIFAR(...)
        mnist_loader = torch.utils.data.DataLoader(
            dataset=mnist, batch_size=self.batch_size, shuffle=True
        )
        cifar_loader = torch.utils.data.DataLoader(
            dataset=cifar, batch_size=self.batch_size, shuffle=True
        )
        # each batch will be a list of tensors: [batch_mnist, batch_cifar]
        return [mnist_loader, cifar_loader]
    # multiple dataloaders, return as dict
    def train_dataloader(self):
        mnist = MNIST(...)
        cifar = CIFAR(...)
        mnist_loader = torch.utils.data.DataLoader(
            dataset=mnist, batch_size=self.batch_size, shuffle=True
        )
        cifar_loader = torch.utils.data.DataLoader(
            dataset=cifar, batch_size=self.batch_size, shuffle=True
        )
        # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
        return {'mnist': mnist_loader, 'cifar': cifar_loader}
"""
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
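The docstring above also mentions ``reload_dataloaders_every_epoch``. As a usage sketch (editor's illustration, not part of this commit), the flag is passed to the ``Trainer`` so that ``train_dataloader`` is called again at every epoch instead of once at the start of fitting; it assumes a ``LitModel`` like the ones in the examples above:

    from pytorch_lightning import Trainer

    model = LitModel()
    # Re-create the training dataloader(s) at the start of each epoch.
    trainer = Trainer(max_epochs=3, reload_dataloaders_every_epoch=True)
    trainer.fit(model)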

9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@
import warnings
from itertools import count
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule):
def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
@@ -437,8 +437,9 @@ def fit(
model: Model to fit.
train_dataloader: A Pytorch DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`.
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
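To illustrate the widened ``fit`` signature (editor's sketch, not part of this commit): the ``train_dataloader`` argument can now also receive a collection of loaders directly, mirroring the dict example from ``hooks.py`` above. The dataset names below are placeholders:

    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer

    # `mnist_dataset` and `cifar_dataset` stand in for real datasets.
    mnist_loader = DataLoader(mnist_dataset, batch_size=32, shuffle=True)
    cifar_loader = DataLoader(cifar_dataset, batch_size=32, shuffle=True)

    trainer = Trainer(max_epochs=1)
    # Each training batch then arrives as {'mnist': batch_mnist, 'cifar': batch_cifar}.
    trainer.fit(model, train_dataloader={'mnist': mnist_loader, 'cifar': cifar_loader})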
