Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Vision DataModules #400

Merged
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d5d4dda
Add BaseDataModule
chris-clem Nov 24, 2020
4876be0
Add pre-commit hooks
chris-clem Nov 24, 2020
9476033
Refactor cifar10_datamodule
chris-clem Nov 24, 2020
0651131
Move torchvision warning
chris-clem Nov 24, 2020
e1f6238
Refactor binary_mnist_datamodule
chris-clem Nov 24, 2020
5b9e2fd
Refactor fashion_mnist_datamodule
chris-clem Nov 24, 2020
9ab0640
Fix errors
chris-clem Nov 24, 2020
b7840bf
Remove VisionDataset type hint so CI base testing does not fail (torc…
chris-clem Nov 24, 2020
4395be3
Implement Nate's suggestions
chris-clem Nov 25, 2020
e82b243
Remove train and eval batch size because it breaks a lot of tests
chris-clem Nov 25, 2020
8e9ae04
Properly add transforms to train and val dataset
chris-clem Nov 25, 2020
790d6e0
Add num_samples property to cifar10 dm
chris-clem Nov 25, 2020
7c9b3ce
Add tests and docs
chris-clem Nov 25, 2020
d86f432
Fix flake8 and CodeFactor issue
chris-clem Nov 25, 2020
e5e69e4
Update changelog
chris-clem Nov 27, 2020
1d821c2
Fix isort
chris-clem Dec 9, 2020
a7d6bd4
Add typing
chris-clem Dec 15, 2020
1d9fa44
Rename to VisionDataModule
chris-clem Dec 15, 2020
6b9bdcb
Remove transform_lib type annotation
chris-clem Dec 15, 2020
8ae4907
suggestions
Borda Dec 16, 2020
4de2ea9
Apply suggestions from code review
Borda Dec 16, 2020
716dedf
Apply suggestions from code review
Borda Dec 17, 2020
54114c5
Add flags from #388 to API
chris-clem Dec 17, 2020
be6fb25
Make tests work
chris-clem Dec 17, 2020
a55ad63
Merge branch 'master' into feature/395_refactor-vision-dms
chris-clem Dec 17, 2020
25382a2
Move _TORCHVISION_AVAILABLE check
chris-clem Dec 17, 2020
4902c1c
Update changelog
chris-clem Dec 17, 2020
2d29a1f
Merge branch 'master' into feature/395_refactor-vision-dms
chris-clem Dec 17, 2020
26eafdf
Merge remote-tracking branch 'origin/feature/395_refactor-vision-dms'…
chris-clem Dec 17, 2020
875f13a
Fix CI base testing
chris-clem Dec 17, 2020
aac84ab
Fix CI base testing
chris-clem Dec 17, 2020
1d5f0b5
Merge remote-tracking branch 'origin/feature/395_refactor-vision-dms'…
chris-clem Dec 17, 2020
4cf20b6
Apply suggestions from code review
akihironitta Dec 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ repos:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
hooks:
- id: mypy
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))

- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))

### Changed

- Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
Expand Down
155 changes: 44 additions & 111 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand All @@ -12,7 +11,7 @@
warn_missing_pkg('torchvision')


class BinaryMNISTDataModule(LightningDataModule):
class BinaryMNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
:width: 400
Expand Down Expand Up @@ -41,136 +40,70 @@ class BinaryMNISTDataModule(LightningDataModule):
"""

name = "binary_mnist"
dataset_cls = BinaryMNIST
dims = (1, 28, 28)
Borda marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
data_dir: str,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args,
**kwargs,
):
self,
data_dir: Optional[str] = None,
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true, applies image normalization
batch_size: size of batch
seed: random seed to be used for train/val/test splits
shuffle: If true shuffles the data every epoch
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
"You want to use transforms loaded from `torchvision` which is not installed yet."
)

self.dims = (1, 28, 28)
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
super().__init__(
data_dir=data_dir,
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)

@property
def num_classes(self):
def num_classes(self) -> int:
"""
Return:
10
"""
return 10

def prepare_data(self):
"""
Saves MNIST files to data_dir
"""
BinaryMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
BinaryMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

def train_dataloader(self):
"""
MNIST train set removes a subset to use for validation
"""
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def val_dataloader(self):
"""
MNIST val set uses a subset of the training set for validation
"""
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def test_dataloader(self):
"""
MNIST test set uses the test split
"""
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def _default_transforms(self):
def default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
mnist_transforms = transform_lib.ToTensor()
mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

return mnist_transforms
Loading