Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pl_bolts.datamodules and pl_bolts.datasets #338

Merged
merged 38 commits into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
3f09070
Remove try: ... except: ...
akihironitta Nov 5, 2020
3f17e4c
Fix experience_source
akihironitta Nov 6, 2020
9ef4cca
Fix imagenet
akihironitta Nov 6, 2020
c398ed3
Fix kitti
akihironitta Nov 6, 2020
053d4dc
Fix sklearn
akihironitta Nov 6, 2020
e06d09a
Fix vocdetection
akihironitta Nov 6, 2020
273fcba
Fix typo
akihironitta Nov 6, 2020
53046f1
Remove duplicate
akihironitta Nov 6, 2020
0b4ea8b
Fix by flake8
akihironitta Nov 6, 2020
04c9898
Merge branch 'master' into fix/datamodules_init
akihironitta Nov 7, 2020
886c0d2
Add optional packages availability vars
akihironitta Nov 25, 2020
fe50e0c
binary_mnist
akihironitta Nov 25, 2020
e183889
Merge branch 'master' into fix/datamodules_init
akihironitta Nov 25, 2020
b5895af
Use pl_bolts._SKLEARN_AVAILABLE
akihironitta Nov 25, 2020
7113645
Apply isort
akihironitta Nov 25, 2020
7fde915
cifar10
akihironitta Nov 25, 2020
69dfd5a
mnist
akihironitta Nov 25, 2020
b2411b5
cityscapes
akihironitta Nov 25, 2020
2f98b4e
fashion mnist
akihironitta Nov 25, 2020
9bdcc65
ssl_imagenet
akihironitta Nov 25, 2020
53f2fe0
stl10
akihironitta Nov 25, 2020
d7bbba2
cifar10
akihironitta Nov 25, 2020
319041d
dummy
akihironitta Nov 25, 2020
451d033
fix city
akihironitta Nov 25, 2020
48821f0
fix stl10
akihironitta Nov 25, 2020
88543a1
fix mnist
akihironitta Nov 25, 2020
8568398
ssl_amdim
akihironitta Nov 25, 2020
1f074ca
remove unused DataLoader and fix docs
akihironitta Nov 25, 2020
d8a9db2
use from ... import ...
akihironitta Nov 25, 2020
8cf8728
fix pragma: no cover
akihironitta Nov 25, 2020
bdd1bd2
Fix forward reference in annotations
akihironitta Nov 25, 2020
8bda71b
binmnist
akihironitta Nov 25, 2020
c2129bd
Same order as imports
akihironitta Nov 25, 2020
faf2281
Merge branch 'master' into fix/datamodules_init
akihironitta Dec 1, 2020
81a04d0
Move vars from __init__ to utils/__init__
akihironitta Dec 14, 2020
43e0abc
Remove vars from __init__
akihironitta Dec 14, 2020
280ed47
Update vars
akihironitta Dec 14, 2020
fc579ea
Apply isort
akihironitta Dec 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 35 additions & 97 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,36 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

__all__ = []

try:
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['BinaryMNISTDataModule']

try:
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['CIFAR10DataModule', 'TinyCIFAR10DataModule']

try:
from pl_bolts.datamodules.experience_source import (
DiscountedExperienceSource,
ExperienceSource,
ExperienceSourceDataset,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['ExperienceSourceDataset', 'ExperienceSource', 'DiscountedExperienceSource']

try:
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['FashionMNISTDataModule']

try:
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['ImagenetDataModule']

try:
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['MNISTDataModule']

try:
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['SklearnDataset', 'SklearnDataModule', 'TensorDataset']

try:
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['SSLImagenetDataModule']

try:
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['STL10DataModule']

try:
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['VOCDetectionDataModule']

try:
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
except ModuleNotFoundError: # pragma: no-cover
pass
else:
__all__ += ['CityscapesDataModule']

try:
from pl_bolts.datasets.kitti_dataset import KittiDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataset']

try:
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataModule']
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
from pl_bolts.datasets.kitti_dataset import KittiDataset

__all__ = [
'AsynchronousLoader',
'BinaryMNISTDataModule',
'CIFAR10DataModule',
'TinyCIFAR10DataModule',
'CityscapesDataModule',
'DiscountedExperienceSource',
'ExperienceSource',
'ExperienceSourceDataset',
'FashionMNISTDataModule',
'ImagenetDataModule',
'KittiDataModule',
'MNISTDataModule',
'SklearnDataModule',
'SklearnDataset',
'TensorDataset',
'SSLImagenetDataModule',
'STL10DataModule',
'VOCDetectionDataModule',
'KittiDataset',
]
17 changes: 7 additions & 10 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST

from pl_bolts.datasets.mnist_dataset import BinaryMNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class BinaryMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -67,8 +64,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
15 changes: 6 additions & 9 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@

from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10

except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CIFAR10DataModule(LightningDataModule):
Expand Down Expand Up @@ -83,8 +80,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import Cityscapes
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CityscapesDataModule(LightningDataModule):
Expand Down Expand Up @@ -82,8 +80,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use CityScapes dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
13 changes: 6 additions & 7 deletions pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
import torch
from torch.utils.data import IterableDataset

from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None
if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg("gym") # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("gym")
Env = object


# Datasets

Experience = namedtuple(
"Experience", field_names=["state", "action", "reward", "done", "new_state"]
)
Expand Down Expand Up @@ -181,7 +180,7 @@ def env_actions(self, device) -> List[List[int]]:

return actions

def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience:
def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
"""
Carries out a step through the given environment using the given action

Expand Down Expand Up @@ -245,7 +244,7 @@ def pop_rewards_steps(self):
class DiscountedExperienceSource(ExperienceSource):
"""Outputs experiences with a discounted reward over N steps"""

def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99):
def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent, (n_steps + 1))
self.gamma = gamma
self.steps = n_steps
Expand Down
14 changes: 6 additions & 8 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class FashionMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -64,8 +62,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use fashion MNIST dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
23 changes: 7 additions & 16 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,15 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True

try:
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
else: # pragma: no cover
warn_missing_pkg('torchvision')


class ImagenetDataModule(LightningDataModule):
Expand Down Expand Up @@ -78,8 +69,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
11 changes: 5 additions & 6 deletions pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os

import torch
Expand All @@ -7,13 +6,13 @@
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
if _TORCHVISION_AVAILABLE:
import torchvision.transforms as transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class KittiDataModule(LightningDataModule):
Expand Down Expand Up @@ -63,8 +62,8 @@ def __init__(
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `torchvision` which is not installed yet.'
)

Expand Down
Loading