Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pl_bolts.datamodules and pl_bolts.datasets #338

Merged
merged 38 commits into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
3f09070
Remove try: ... except: ...
akihironitta Nov 5, 2020
3f17e4c
Fix experience_source
akihironitta Nov 6, 2020
9ef4cca
Fix imagenet
akihironitta Nov 6, 2020
c398ed3
Fix kitti
akihironitta Nov 6, 2020
053d4dc
Fix sklearn
akihironitta Nov 6, 2020
e06d09a
Fix vocdetection
akihironitta Nov 6, 2020
273fcba
Fix typo
akihironitta Nov 6, 2020
53046f1
Remove duplicate
akihironitta Nov 6, 2020
0b4ea8b
Fix by flake8
akihironitta Nov 6, 2020
04c9898
Merge branch 'master' into fix/datamodules_init
akihironitta Nov 7, 2020
886c0d2
Add optional packages availability vars
akihironitta Nov 25, 2020
fe50e0c
binary_mnist
akihironitta Nov 25, 2020
e183889
Merge branch 'master' into fix/datamodules_init
akihironitta Nov 25, 2020
b5895af
Use pl_bolts._SKLEARN_AVAILABLE
akihironitta Nov 25, 2020
7113645
Apply isort
akihironitta Nov 25, 2020
7fde915
cifar10
akihironitta Nov 25, 2020
69dfd5a
mnist
akihironitta Nov 25, 2020
b2411b5
cityscapes
akihironitta Nov 25, 2020
2f98b4e
fashion mnist
akihironitta Nov 25, 2020
9bdcc65
ssl_imagenet
akihironitta Nov 25, 2020
53f2fe0
stl10
akihironitta Nov 25, 2020
d7bbba2
cifar10
akihironitta Nov 25, 2020
319041d
dummy
akihironitta Nov 25, 2020
451d033
fix city
akihironitta Nov 25, 2020
48821f0
fix stl10
akihironitta Nov 25, 2020
88543a1
fix mnist
akihironitta Nov 25, 2020
8568398
ssl_amdim
akihironitta Nov 25, 2020
1f074ca
remove unused DataLoader and fix docs
akihironitta Nov 25, 2020
d8a9db2
use from ... import ...
akihironitta Nov 25, 2020
8cf8728
fix pragma: no cover
akihironitta Nov 25, 2020
bdd1bd2
Fix forward reference in annotations
akihironitta Nov 25, 2020
8bda71b
binmnist
akihironitta Nov 25, 2020
c2129bd
Same order as imports
akihironitta Nov 25, 2020
faf2281
Merge branch 'master' into fix/datamodules_init
akihironitta Dec 1, 2020
81a04d0
Move vars from __init__ to utils/__init__
akihironitta Dec 14, 2020
43e0abc
Remove vars from __init__
akihironitta Dec 14, 2020
280ed47
Update vars
akihironitta Dec 14, 2020
fc579ea
Apply isort
akihironitta Dec 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 47 additions & 104 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,48 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

__all__ = []

try:
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['BinaryMNISTDataModule']

try:
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['CIFAR10DataModule', 'TinyCIFAR10DataModule']

try:
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['ExperienceSourceDataset', 'ExperienceSource', 'DiscountedExperienceSource']

try:
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['FashionMNISTDataModule']

try:
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['ImagenetDataModule']

try:
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['MNISTDataModule']

try:
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['SklearnDataset', 'SklearnDataModule', 'TensorDataset']

try:
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['SSLImagenetDataModule']

try:
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['STL10DataModule']

try:
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['VOCDetectionDataModule']

try:
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
except ModuleNotFoundError: # pragma: no-cover
pass
else:
__all__ += ['CityscapesDataModule']

try:
from pl_bolts.datasets.kitti_dataset import KittiDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataset']

try:
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataModule']
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
)
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule


# Public API of the `pl_bolts.datamodules` package: these are the names that
# `from pl_bolts.datamodules import *` exports. Every entry corresponds to a
# name imported at the top of this `__init__.py`.
__all__ = [
'BinaryMNISTDataModule',
'CIFAR10DataModule',
'TinyCIFAR10DataModule',
'ExperienceSourceDataset',
'ExperienceSource',
'DiscountedExperienceSource',
'FashionMNISTDataModule',
'ImagenetDataModule',
'MNISTDataModule',
'SklearnDataset',
'SklearnDataModule',
'TensorDataset',
'SSLImagenetDataModule',
'STL10DataModule',
'VOCDetectionDataModule',
'CityscapesDataModule',
'KittiDataset',
'KittiDataModule',
'AsynchronousLoader',
]
10 changes: 7 additions & 3 deletions pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
"""
from abc import ABC
from collections import deque, namedtuple
import importlib
from typing import Iterable, Callable, Tuple, List

import torch
from gym import Env
from torch.utils.data import IterableDataset

# Import the submodule explicitly: a bare `import importlib` does not guarantee
# that `importlib.util` has been loaded, so `importlib.util.find_spec` can raise
# AttributeError depending on what else was imported first.
from importlib.util import find_spec

# `gym` is an optional dependency. Probe for it without importing it, and only
# bind `Env` when the package is present; annotations that mention `Env` are
# written as string forward references so the module still imports without gym.
_GYM_AVAILABLE = find_spec("gym") is not None
if _GYM_AVAILABLE:
    from gym import Env

# Datasets

Experience = namedtuple(
Expand Down Expand Up @@ -172,7 +176,7 @@ def env_actions(self, device) -> List[List[int]]:

return actions

def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience:
"""
Carries out a step through the given environment using the given action

Expand Down Expand Up @@ -236,7 +240,7 @@ def pop_rewards_steps(self):
class DiscountedExperienceSource(ExperienceSource):
"""Outputs experiences with a discounted reward over N steps"""

def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent, (n_steps + 1))
self.gamma = gamma
self.steps = n_steps
Expand Down
16 changes: 4 additions & 12 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
from typing import Optional

Expand All @@ -7,21 +8,12 @@
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils.warnings import warn_missing_pkg

try:
_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True

try:
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
warn_missing_pkg('torchvision') # pragma: no-cover


class ImagenetDataModule(LightningDataModule):
Expand Down
11 changes: 10 additions & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import importlib
import os

import torch
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset

# Import the submodule explicitly: a bare `import importlib` does not guarantee
# that `importlib.util` has been loaded, so `importlib.util.find_spec` can raise
# AttributeError depending on what else was imported first.
from importlib.util import find_spec

# `torchvision` is an optional dependency. Probe for it without importing it,
# and only bind `transforms` when the package is present.
_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
if _TORCHVISION_AVAILABLE:
    import torchvision.transforms as transforms


class KittiDataModule(LightningDataModule):

Expand Down Expand Up @@ -56,6 +60,11 @@ def __init__(
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

super().__init__(*args, **kwargs)
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.batch_size = batch_size
Expand Down
11 changes: 4 additions & 7 deletions pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from importlib.util import find_spec
import math
from typing import Any

Expand All @@ -6,14 +7,10 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

try:

_SKLEARN_AVAILABLE = find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
from sklearn.utils import shuffle as sk_shuffle
except ModuleNotFoundError:
raise ModuleNotFoundError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')
_SKLEARN_AVAILABLE = False
else:
_SKLEARN_AVAILABLE = True


class SklearnDataset(Dataset):
Expand Down
17 changes: 8 additions & 9 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import importlib

import torch
import torchvision.transforms as T
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.utils.warnings import warn_missing_pkg

try:
_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
if _TORCHVISION_AVAILABLE:
import torchvision.transforms as T
from torchvision.datasets import VOCDetection

except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
warn_missing_pkg('torchvision') # pragma: no-cover


class Compose(object):
Expand Down Expand Up @@ -117,13 +116,13 @@ def __init__(
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
)

super().__init__(*args, **kwargs)

self.year = year
self.data_dir = data_dir
self.num_workers = num_workers
Expand Down
15 changes: 10 additions & 5 deletions pl_bolts/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
Expand All @@ -11,13 +12,12 @@
import torch
from torch._six import PY3

try:
_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
if _TORCHVISION_AVAILABLE:
from torchvision.datasets import ImageNet
from torchvision.datasets.imagenet import load_meta_file
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
) from err
else:
ImageNet = object # pragma: no-cover


class UnlabeledImagenet(ImageNet):
Expand Down Expand Up @@ -48,6 +48,11 @@ def __init__(
download:
kwargs:
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
)

root = self.root = os.path.expanduser(root)

# [train], [val] --> [train, val], [test]
Expand Down
12 changes: 11 additions & 1 deletion pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import importlib
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

# Import the submodule explicitly: a bare `import importlib` does not guarantee
# that `importlib.util` has been loaded, so `importlib.util.find_spec` can raise
# AttributeError depending on what else was imported first.
from importlib.util import find_spec

# Pillow (`PIL`) is an optional dependency. Probe for it without importing it,
# and only bind `Image` when the package is present.
_PIL_AVAILABLE = find_spec("PIL") is not None
if _PIL_AVAILABLE:
    from PIL import Image


DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)

Expand Down Expand Up @@ -41,6 +46,11 @@ def __init__(
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
if not _PIL_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `Pillow` which is not installed yet.'
)

self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
Expand Down