Use Lightning DataModules (#130)
* ✨ Use pl.LightningDataModule

* 📝 add docs

* 🔥 remove bolts LightningDataModule

* 📝 add docs.

* 🎨 remove batch_size arg from loaders, move to init

* 📌 Pin Pytorch Lightning v0.9.0rc3

* 🎨 cleanup dm

* 🎨 cleanup dm

* ✅ update tests

* ✅ update tests

* 🎨 cleanup dm

* 🎨 cleanup dm

* ✅ add random test to make coverage pass

* 🎨 update dm

* ✅ add another random test to pass coverage

* ✅ add lars tests
nateraw authored Jul 29, 2020
1 parent de5f972 commit 07dcc0f
Showing 23 changed files with 129 additions and 390 deletions.
11 changes: 3 additions & 8 deletions docs/source/datamodules.rst
@@ -14,7 +14,8 @@ Step 4, also needs special care to make sure that it's only done on 1 GPU in a m
In addition, there are other challenges such as models that are built using information from the dataset
such as needing to know image dimensions or number of classes.

-A datamodule simplifies all of these parts and integrates seamlessly into Lightning.
+A datamodule simplifies all of these parts and has been integrated directly into Lightning in version 0.9.0.
+You can view the documentation for the datamodule in the `Pytorch Lightning docs here. <https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html>`_

.. code-block:: python
@@ -92,7 +93,7 @@ Use this to build your own consistent train, validation, test splits.

Example::

-from pl_bolts.datamodules import LightningDataModule
+from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):

@@ -157,12 +158,6 @@ or::
for b in dataloader:
...

-DataModule class
-^^^^^^^^^^^^^^^^
-
-.. autoclass:: pl_bolts.datamodules.lightning_datamodule.LightningDataModule
-    :noindex:

-------------

DummyDataset
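For context, a minimal sketch of what a user-defined datamodule looks like against the upstream base class after this change. This is not part of the diff; the MyDataModule class body, the default batch_size value, and the random-tensor dataset are purely illustrative::

    from pytorch_lightning import LightningDataModule
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class MyDataModule(LightningDataModule):

        def __init__(self, batch_size=32):
            super().__init__()
            # following this commit's convention, batch_size lives on the
            # datamodule instead of being passed to each *_dataloader() call
            self.batch_size = batch_size

        def setup(self, stage=None):
            # random tensors stand in for a real dataset in this sketch
            self.train_set = TensorDataset(
                torch.randn(256, 3, 32, 32),
                torch.randint(0, 10, (256,)),
            )

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)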
1 change: 0 additions & 1 deletion pl_bolts/datamodules/__init__.py
@@ -1,6 +1,5 @@
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
-from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
27 changes: 10 additions & 17 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -1,11 +1,11 @@
from typing import Optional, Sequence

+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10

from pl_bolts.datamodules.cifar10_dataset import TrialCIFAR10
-from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


@@ -19,6 +19,7 @@ def __init__(
data_dir,
val_split=5000,
num_workers=16,
+batch_size=32,
*args,
**kwargs,
):
@@ -54,13 +55,15 @@ def __init__(
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
+batch_size: number of examples per training/eval step
"""
super().__init__(*args, **kwargs)
self.dims = (3, 32, 32)
self.DATASET = CIFAR10
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
+self.batch_size = batch_size

@property
def num_classes(self):
@@ -77,12 +80,9 @@ def prepare_data(self):
self.DATASET(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor(), **self.extra_args)
self.DATASET(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor(), **self.extra_args)

-def train_dataloader(self, batch_size):
+def train_dataloader(self):
"""
CIFAR train set removes a subset to use for validation
-Args:
-    batch_size: size of batch
"""
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms

@@ -91,20 +91,17 @@ def train_dataloader(self, batch_size):
dataset_train, _ = random_split(dataset, [train_length - self.val_split, self.val_split])
loader = DataLoader(
dataset_train,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

-def val_dataloader(self, batch_size):
+def val_dataloader(self):
"""
CIFAR10 val set uses a subset of the training set for validation
-Args:
-    batch_size: size of batch
"""
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms

@@ -113,28 +110,24 @@ def val_dataloader(self, batch_size):
_, dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
loader = DataLoader(
dataset_val,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
return loader

-def test_dataloader(self, batch_size):
+def test_dataloader(self):
"""
CIFAR10 test set uses the test split
-Args:
-    batch_size: size of batch
-    transforms: custom transforms
"""
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms

dataset = self.DATASET(self.data_dir, train=False, download=False, transform=transforms, **self.extra_args)
loader = DataLoader(
dataset,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
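With this change, batch_size is configured once when the datamodule is constructed and the dataloader methods take no arguments. A short usage sketch, not part of the diff (the data_dir path and values are placeholders)::

    from pl_bolts.datamodules import CIFAR10DataModule

    # './data' is an illustrative path; val_split/num_workers/batch_size mirror the defaults in the diff
    dm = CIFAR10DataModule(data_dir='./data', val_split=5000, num_workers=16, batch_size=32)
    dm.prepare_data()

    train_loader = dm.train_dataloader()  # no batch_size argument any more
    val_loader = dm.val_dataloader()
    test_loader = dm.test_dataloader()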
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -1,9 +1,8 @@
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST

-from pl_bolts.datamodules.lightning_datamodule import LightningDataModule


class FashionMNISTDataModule(LightningDataModule):

2 changes: 1 addition & 1 deletion pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,10 +1,10 @@
import os

+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import transforms as transform_lib

from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet
-from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization


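Since the bolts-local base class is removed by this commit, any downstream code that imported it from pl_bolts should switch to the upstream import. An illustrative migration snippet, not part of the diff::

    # before this commit (base class removed from pl_bolts):
    # from pl_bolts.datamodules import LightningDataModule

    # after this commit, import the base class from Lightning itself:
    from pytorch_lightning import LightningDataModule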
