Skip to content

Commit

Permalink
Refactor datamodules/datasets (#338)
Browse files Browse the repository at this point in the history
* Remove try: ... except: ...

* Fix experience_source

* Fix imagenet

* Fix kitti

* Fix sklearn

* Fix vocdetection

* Fix typo

* Remove duplicate

* Fix by flake8

* Add optional packages availability vars

* binary_mnist

* Use pl_bolts._SKLEARN_AVAILABLE

* Apply isort

* cifar10

* mnist

* cityscapes

* fashion mnist

* ssl_imagenet

* stl10

* cifar10

* dummy

* fix city

* fix stl10

* fix mnist

* ssl_amdim

* remove unused DataLoader and fix docs

* use from ... import ...

* fix pragma: no cover

* Fix forward reference in annotations

* binmnist

* Same order as imports

* Move vars from __init__ to utils/__init__

* Remove vars from __init__

* Update vars

* Apply isort
  • Loading branch information
akihironitta authored Dec 14, 2020
1 parent 2dfd598 commit 58536c2
Show file tree
Hide file tree
Showing 22 changed files with 176 additions and 262 deletions.
132 changes: 35 additions & 97 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,36 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

__all__ = []

try:
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['BinaryMNISTDataModule']

try:
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['CIFAR10DataModule', 'TinyCIFAR10DataModule']

try:
from pl_bolts.datamodules.experience_source import (
DiscountedExperienceSource,
ExperienceSource,
ExperienceSourceDataset,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['ExperienceSourceDataset', 'ExperienceSource', 'DiscountedExperienceSource']

try:
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['FashionMNISTDataModule']

try:
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['ImagenetDataModule']

try:
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['MNISTDataModule']

try:
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['SklearnDataset', 'SklearnDataModule', 'TensorDataset']

try:
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['SSLImagenetDataModule']

try:
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['STL10DataModule']

try:
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['VOCDetectionDataModule']

try:
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
except ModuleNotFoundError: # pragma: no-cover
pass
else:
__all__ += ['CityscapesDataModule']

try:
from pl_bolts.datasets.kitti_dataset import KittiDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataset']

try:
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataModule']
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
from pl_bolts.datasets.kitti_dataset import KittiDataset

__all__ = [
'AsynchronousLoader',
'BinaryMNISTDataModule',
'CIFAR10DataModule',
'TinyCIFAR10DataModule',
'CityscapesDataModule',
'DiscountedExperienceSource',
'ExperienceSource',
'ExperienceSourceDataset',
'FashionMNISTDataModule',
'ImagenetDataModule',
'KittiDataModule',
'MNISTDataModule',
'SklearnDataModule',
'SklearnDataset',
'TensorDataset',
'SSLImagenetDataModule',
'STL10DataModule',
'VOCDetectionDataModule',
'KittiDataset',
]
17 changes: 7 additions & 10 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST

from pl_bolts.datasets.mnist_dataset import BinaryMNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class BinaryMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -67,8 +64,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
15 changes: 6 additions & 9 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@

from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10

except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CIFAR10DataModule(LightningDataModule):
Expand Down Expand Up @@ -83,8 +80,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import Cityscapes
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CityscapesDataModule(LightningDataModule):
Expand Down Expand Up @@ -82,8 +80,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use CityScapes dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
13 changes: 6 additions & 7 deletions pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
import torch
from torch.utils.data import IterableDataset

from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None
if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg("gym") # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("gym")
Env = object


# Datasets

Experience = namedtuple(
"Experience", field_names=["state", "action", "reward", "done", "new_state"]
)
Expand Down Expand Up @@ -181,7 +180,7 @@ def env_actions(self, device) -> List[List[int]]:

return actions

def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience:
def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
"""
Carries out a step through the given environment using the given action
Expand Down Expand Up @@ -245,7 +244,7 @@ def pop_rewards_steps(self):
class DiscountedExperienceSource(ExperienceSource):
"""Outputs experiences with a discounted reward over N steps"""

def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99):
def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent, (n_steps + 1))
self.gamma = gamma
self.steps = n_steps
Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class FashionMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -64,8 +62,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use fashion MNIST dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
23 changes: 7 additions & 16 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,15 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True

try:
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class ImagenetDataModule(LightningDataModule):
Expand Down Expand Up @@ -78,8 +69,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
11 changes: 5 additions & 6 deletions pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os

import torch
Expand All @@ -7,13 +6,13 @@
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
if _TORCHVISION_AVAILABLE:
import torchvision.transforms as transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class KittiDataModule(LightningDataModule):
Expand Down Expand Up @@ -63,8 +62,8 @@ def __init__(
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `torchvision` which is not installed yet.'
)

Expand Down
Loading

0 comments on commit 58536c2

Please sign in to comment.