Showing 4 changed files with 248 additions and 324 deletions.
@@ -0,0 +1,243 @@
from typing import List, Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_warn


class DataHooks:

    def setup(self, stage: str):
        """
        Called at the beginning of fit and test.
        This is a good hook when you need to build models dynamically or adjust something about them.
        This hook is called on every process when using DDP.

        Args:
            stage: either 'fit' or 'test'

        Example::

            class LitModel(...):
                def __init__(self):
                    self.l1 = None

                def prepare_data(self):
                    download_data()
                    tokenize()

                    # don't do this
                    self.something = something_else

                def setup(self, stage):
                    data = load_data(...)
                    self.l1 = nn.Linear(28, data.num_classes)
        """

    def prepare_data(self) -> None:
        """
        Use this to download and prepare data.

        .. warning:: DO NOT set state on the model (use `setup` instead)
            since this is NOT called on every GPU in DDP/TPU

        Example::

            def prepare_data(self):
                # good
                download_data()
                tokenize()
                etc()

                # bad
                self.split = data_split
                self.some_state = some_other_state()

        In DDP, prepare_data can be called in two ways (controlled by Trainer(prepare_data_per_node)):

        1. Once per node. This is the default and is only called on LOCAL_RANK=0.
        2. Once in total. Only called on GLOBAL_RANK=0.

        Example::

            # DEFAULT
            # called once per node on LOCAL_RANK=0 of that node
            Trainer(prepare_data_per_node=True)

            # call on GLOBAL_RANK=0 (great for shared file systems)
            Trainer(prepare_data_per_node=False)

        This is called before requesting the dataloaders:

        .. code-block:: python

            model.prepare_data()
            if ddp/tpu: init()
            model.setup(stage)
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()
        """

    def train_dataloader(self) -> DataLoader:
        """
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        For data processing use the following pattern:

            - download in :meth:`prepare_data`
            - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`setup`
        - :meth:`train_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Example:
            .. code-block:: python

                def train_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                                    download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=True
                    )
                    return loader
        """
        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

    def tng_dataloader(self):  # todo: remove in v1.0.0
        """
        Warnings:
            Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in v1.0.0.
        """
        output = self.train_dataloader()
        rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0"
                       " and this method will be removed in v1.0.0", DeprecationWarning)
        return output

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement one or multiple PyTorch DataLoaders for testing.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        For data processing use the following pattern:

            - download in :meth:`prepare_data`
            - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`setup`
        - :meth:`train_dataloader`
        - :meth:`val_dataloader`
        - :meth:`test_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Return:
            Single or multiple PyTorch DataLoaders.

        Example:
            .. code-block:: python

                def test_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                                    download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                    return loader

        Note:
            If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
            this method.
        """

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement one or multiple PyTorch DataLoaders for validation.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`train_dataloader`
        - :meth:`val_dataloader`
        - :meth:`test_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Return:
            Single or multiple PyTorch DataLoaders.

        Examples:
            .. code-block:: python

                def val_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=False,
                                    transform=transform, download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                    return loader

                # can also return multiple dataloaders
                def val_dataloader(self):
                    return [loader_a, loader_b, ..., loader_n]

        Note:
            If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
            implement this method.

        Note:
            In the case where you return multiple validation dataloaders, the :meth:`validation_step`
            will have an argument ``dataset_idx`` which matches the order here.
        """