From 3f09070554c0b18759897267e3edc3072388b29f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 08:06:08 +0900 Subject: [PATCH 01/35] Remove try: ... except: ... --- pl_bolts/datamodules/__init__.py | 148 ++++++++++--------------------- 1 file changed, 46 insertions(+), 102 deletions(-) diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 29f29f2ec9..fc23797081 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,105 +1,49 @@ from pl_bolts.datamodules.async_dataloader import AsynchronousLoader +from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule +from pl_bolts.datamodules.cifar10_datamodule import ( + CIFAR10DataModule, + TinyCIFAR10DataModule, +) +from pl_bolts.datamodules.experience_source import ( + ExperienceSourceDataset, + ExperienceSource, + DiscountedExperienceSource, +) +from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule +from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule +from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule +from pl_bolts.datamodules.sklearn_datamodule import ( + SklearnDataset, + SklearnDataModule, + TensorDataset, +) +from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule +from pl_bolts.datamodules.stl10_datamodule import STL10DataModule +from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule +from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule +from pl_bolts.datasets.kitti_dataset import KittiDataset +from pl_bolts.datamodules.kitti_datamodule import KittiDataModule +from pl_bolts.datamodules.async_dataloader import AsynchronousLoader -__all__ = [] - -try: - from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['BinaryMNISTDataModule'] - -try: - from pl_bolts.datamodules.cifar10_datamodule import ( - CIFAR10DataModule, - TinyCIFAR10DataModule, - ) -except ModuleNotFoundError: - pass -else: - __all__ += ['CIFAR10DataModule', 'TinyCIFAR10DataModule'] - -try: - from pl_bolts.datamodules.experience_source import ( - ExperienceSourceDataset, - ExperienceSource, - DiscountedExperienceSource, - ) -except ModuleNotFoundError: - pass -else: - __all__ += ['ExperienceSourceDataset', 'ExperienceSource', 'DiscountedExperienceSource'] - -try: - from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['FashionMNISTDataModule'] - -try: - from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['ImagenetDataModule'] - -try: - from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['MNISTDataModule'] - -try: - from pl_bolts.datamodules.sklearn_datamodule import ( - SklearnDataset, - SklearnDataModule, - TensorDataset, - ) -except ModuleNotFoundError: - pass -else: - __all__ += ['SklearnDataset', 'SklearnDataModule', 'TensorDataset'] - -try: - from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['SSLImagenetDataModule'] - -try: - from pl_bolts.datamodules.stl10_datamodule import STL10DataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['STL10DataModule'] - -try: - from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['VOCDetectionDataModule'] - -try: - from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule -except ModuleNotFoundError: # pragma: no-cover - pass -else: - __all__ += ['CityscapesDataModule'] - -try: - from pl_bolts.datasets.kitti_dataset import KittiDataset -except ModuleNotFoundError: - pass -else: - __all__ += ['KittiDataset'] -try: - from pl_bolts.datamodules.kitti_datamodule import KittiDataModule -except ModuleNotFoundError: - pass -else: - __all__ += ['KittiDataModule'] +__all__ = [ + 'BinaryMNISTDataModule', + 'CIFAR10DataModule', + 'TinyCIFAR10DataModule', + 'ExperienceSourceDataset', + 'ExperienceSource', + 'DiscountedExperienceSource', + 'FashionMNISTDataModule', + 'ImagenetDataModule', + 'MNISTDataModule', + 'SklearnDataset', + 'SklearnDataModule', + 'TensorDataset', + 'SSLImagenetDataModule', + 'STL10DataModule', + 'VOCDetectionDataModule', + 'CityscapesDataModule', + 'KittiDataset', + 'KittiDataModule', + 'AsynchronousLoader', +] From 3f17e4cea71e13803b9cf3c6c323dd046e6c7fc6 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:25:30 +0900 Subject: [PATCH 02/35] Fix experience_source --- pl_bolts/datamodules/experience_source.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index e4df09c85d..3a56072230 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -4,12 +4,16 @@ """ from abc import ABC from collections import deque, namedtuple +import importlib from typing import Iterable, Callable, Tuple, List import torch -from gym import Env from torch.utils.data import IterableDataset +_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None +if _GYM_AVAILABLE: + from gym import Env + # Datasets Experience = namedtuple( @@ -172,7 +176,7 @@ def env_actions(self, device) -> List[List[int]]: return actions - def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience: + def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience: """ Carries out a step through the given environment using the given action @@ -236,7 +240,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99): super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps From 9ef4ccabf9a3af0d18e987a7624cefcec542ccbf Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:25:57 +0900 Subject: [PATCH 03/35] Fix imagenet --- pl_bolts/datamodules/imagenet_datamodule.py | 17 ++++------------- pl_bolts/datasets/imagenet_dataset.py | 15 ++++++++++----- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 806a63ffe1..00ef694c3c 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,3 +1,4 @@ +import importlib import os from typing import Optional from warnings import warn @@ -7,23 +8,13 @@ from pl_bolts.transforms.dataset_normalizations import imagenet_normalization -try: +_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib -except ModuleNotFoundError: - warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover - ' install it with `pip install torchvision`.') - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True - -try: from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet -except ModuleNotFoundError: +else: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True class ImagenetDataModule(LightningDataModule): diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index f070bebe1a..44cfe87bfb 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -1,5 +1,6 @@ import gzip import hashlib +import importlib import os import shutil import tarfile @@ -11,13 +12,12 @@ import torch from torch._six import PY3 -try: +_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVISION_AVAILABLE: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file -except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) from err +else: + ImageNet = object # pragma: no-cover class UnlabeledImagenet(ImageNet): @@ -48,6 +48,11 @@ def __init__( download: kwargs: """ + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' + ) + root = self.root = os.path.expanduser(root) # [train], [val] --> [train, val], [test] From c398ed3fffc84e10b6c3c66992d6016c700424e4 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:26:15 +0900 Subject: [PATCH 04/35] Fix kitti --- pl_bolts/datamodules/kitti_datamodule.py | 11 ++++++++++- pl_bolts/datasets/kitti_dataset.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 5b39228742..3dcbc4f720 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,13 +1,17 @@ +import importlib import os import torch -import torchvision.transforms as transforms from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split from pl_bolts.datasets.kitti_dataset import KittiDataset +_TORCHVIVION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVIVION_AVAILABLE: + import torchvision.transforms as transforms + class KittiDataModule(LightningDataModule): @@ -56,6 +60,11 @@ def __init__( batch_size: the batch size seed: random seed to be used for train/val/test splits """ + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) + super().__init__(*args, **kwargs) self.data_dir = data_dir if data_dir is not None else os.getcwd() self.batch_size = batch_size diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index bd8c774c39..edbb7d95e1 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -1,9 +1,14 @@ +import importlib import os import numpy as np -from PIL import Image from torch.utils.data import Dataset +_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None +if _PIL_AVAILABLE: + from PIL import Image + + DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) @@ -41,6 +46,11 @@ def __init__( void_labels: useless classes to be excluded from training valid_labels: useful classes to include """ + if not _PIL_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `Pillow` which is not installed yet.' + ) + self.img_size = img_size self.void_labels = void_labels self.valid_labels = valid_labels From 053d4dc8ebbfe962c9a84646961e175333d44c90 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:26:36 +0900 Subject: [PATCH 05/35] Fix sklearn --- pl_bolts/datamodules/sklearn_datamodule.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 4415a3e769..c127c64eec 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,3 +1,4 @@ +from importlib.util import find_spec import math from typing import Any @@ -6,14 +7,10 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import Dataset, DataLoader -try: + +_SKLEARN_AVAILABLE = find_spec("sklearn") is not None +if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle -except ModuleNotFoundError: - raise ModuleNotFoundError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover - ' install it with `pip install sklearn`.') - _SKLEARN_AVAILABLE = False -else: - _SKLEARN_AVAILABLE = True class SklearnDataset(Dataset): From e06d09a94915551de6c611118c95ea3a343ef038 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:26:48 +0900 Subject: [PATCH 06/35] Fix vocdetection --- pl_bolts/datamodules/vocdetection_datamodule.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 580aca09f2..8c501a8570 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,19 +1,17 @@ +import importlib from warnings import warn import torch -import torchvision.transforms as T from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -try: +_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVISION_AVAILABLE: + import torchvision.transforms as T from torchvision.datasets import VOCDetection - -except ModuleNotFoundError: +else: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True class Compose(object): @@ -118,12 +116,12 @@ def __init__( *args, **kwargs, ): - super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' ) + + super().__init__(*args, **kwargs) self.year = year self.data_dir = data_dir From 273fcbaee379dcfa271d1860803ae795dc44de4d Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:31:30 +0900 Subject: [PATCH 07/35] Fix typo --- pl_bolts/datamodules/kitti_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 3dcbc4f720..cb1ff721a2 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -8,8 +8,8 @@ from pl_bolts.datasets.kitti_dataset import KittiDataset -_TORCHVIVION_AVAILABLE = importlib.util.find_spec("torchvision") is not None -if _TORCHVIVION_AVAILABLE: +_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms From 53046f1fc315522b8322e4a0da64a87212fdc451 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:33:03 +0900 Subject: [PATCH 08/35] Remove duplicate --- pl_bolts/datamodules/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index fc23797081..21eb5c5e1d 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -23,7 +23,6 @@ from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule from pl_bolts.datasets.kitti_dataset import KittiDataset from pl_bolts.datamodules.kitti_datamodule import KittiDataModule -from pl_bolts.datamodules.async_dataloader import AsynchronousLoader __all__ = [ From 0b4ea8b6fa675cbf51d13329817fefe3c4306462 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 6 Nov 2020 10:34:07 +0900 Subject: [PATCH 09/35] Fix by flake8 --- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- pl_bolts/datasets/imagenet_dataset.py | 2 +- pl_bolts/datasets/kitti_dataset.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index cb1ff721a2..071eb0ee7c 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -64,7 +64,7 @@ def __init__( raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' ) - + super().__init__(*args, **kwargs) self.data_dir = data_dir if data_dir is not None else os.getcwd() self.batch_size = batch_size diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 8c501a8570..257174c3d7 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -120,7 +120,7 @@ def __init__( raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' ) - + super().__init__(*args, **kwargs) self.year = year diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index 44cfe87bfb..731a022fe8 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -52,7 +52,7 @@ def __init__( raise ModuleNotFoundError( # pragma: no-cover 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' ) - + root = self.root = os.path.expanduser(root) # [train], [val] --> [train, val], [test] diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index edbb7d95e1..5a1b3eddf4 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -50,7 +50,7 @@ def __init__( raise ModuleNotFoundError( # pragma: no-cover 'You want to use `Pillow` which is not installed yet.' ) - + self.img_size = img_size self.void_labels = void_labels self.valid_labels = valid_labels From 886c0d240f232eb99ba2daa2513d8b8267999a50 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 03:11:02 +0900 Subject: [PATCH 10/35] Add optional packages availability vars --- pl_bolts/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index 661c51e9c7..4ae74b2e90 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -1,6 +1,7 @@ """Root package info.""" import os +from importlib.util import find_spec __version__ = '0.2.5rc1' __author__ = 'PyTorchLightning et al.' @@ -30,6 +31,12 @@ PACKAGE_ROOT = os.path.dirname(__file__) +_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None +_SKLEARN_AVAILABLE = find_spec("sklearn") is not None +_PIL_AVAILABLE = find_spec("PIL") is not None +_GYM_AVAILABLE = find_spec("gym") is not None +_OPENCV_AVAILABLE = find_spec("cv2") is not None + try: # This variable is injected in the __builtins__ by the build process. # It used to enable importing subpackages when the binaries are not built. From fe50e0c6476c26f68a0b775e20839adfcbb6b9b3 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 03:14:22 +0900 Subject: [PATCH 11/35] binary_mnist --- pl_bolts/datamodules/binary_mnist_datamodule.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index fc08c6f5a8..110591d4b3 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -2,17 +2,15 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST from pl_bolts.datasets.mnist_dataset import BinaryMNIST -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False else: - _TORCHVISION_AVAILABLE = True + warn_missing_pkg('torchvision') # pragma: no-cover class BinaryMNISTDataModule(LightningDataModule): From b5895af05877060f0cb59e813407eb5f1873e03c Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 04:25:13 +0900 Subject: [PATCH 12/35] Use pl_bolts._SKLEARN_AVAILABLE --- pl_bolts/utils/semi_supervised.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 1dd7ab43c1..82f9343b01 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -1,16 +1,15 @@ -import importlib import math import numpy as np import torch +from pl_bolts import _SKLEARN_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle -else: - warn_missing_pkg('sklearn', pypi_name='scikit-learn') # pragma: no-cover +else: # pragma: no-cover + warn_missing_pkg('sklearn', pypi_name='scikit-learn') class Identity(torch.nn.Module): From 7113645abe435b3e2e2834856a54c0330d6edaab Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 04:26:30 +0900 Subject: [PATCH 13/35] Apply isort --- pl_bolts/datamodules/__init__.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 21eb5c5e1d..d914c4cea1 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,29 +1,17 @@ from pl_bolts.datamodules.async_dataloader import AsynchronousLoader from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule -from pl_bolts.datamodules.cifar10_datamodule import ( - CIFAR10DataModule, - TinyCIFAR10DataModule, -) -from pl_bolts.datamodules.experience_source import ( - ExperienceSourceDataset, - ExperienceSource, - DiscountedExperienceSource, -) +from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule +from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule +from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule +from pl_bolts.datamodules.kitti_datamodule import KittiDataModule from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule -from pl_bolts.datamodules.sklearn_datamodule import ( - SklearnDataset, - SklearnDataModule, - TensorDataset, -) +from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule -from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule from pl_bolts.datasets.kitti_dataset import KittiDataset -from pl_bolts.datamodules.kitti_datamodule import KittiDataModule - __all__ = [ 'BinaryMNISTDataModule', From 7fde915dbb753b46554d71e69f39c5c34354f200 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:01:27 +0900 Subject: [PATCH 14/35] cifar10 --- pl_bolts/datamodules/cifar10_datamodule.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 19535b2932..b6745a4233 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -5,19 +5,16 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 - -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class CIFAR10DataModule(LightningDataModule): From 69dfd5af552a6e026221b29b02940c9ef7421ce3 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:02:39 +0900 Subject: [PATCH 15/35] mnist --- pl_bolts/datamodules/mnist_datamodule.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 38d3d001bb..a942958048 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -2,16 +2,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class MNISTDataModule(LightningDataModule): From b2411b510156fc78a2055f8033ce99d8a1a3233b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:03:23 +0900 Subject: [PATCH 16/35] cityscapes --- pl_bolts/datamodules/cityscapes_datamodule.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index a0d623253e..634ea13e6a 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -3,14 +3,11 @@ from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class CityscapesDataModule(LightningDataModule): From 2f98b4eab70a2d9c06cdf469761427d337f541aa Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:04:31 +0900 Subject: [PATCH 17/35] fashion mnist --- pl_bolts/datamodules/fashion_mnist_datamodule.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 32ccb2ce81..73182d4020 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -2,16 +2,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class FashionMNISTDataModule(LightningDataModule): From 9bdcc658ceeb3c361442cd0e5dd9be3eb5d77bf4 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:10:07 +0900 Subject: [PATCH 18/35] ssl_imagenet --- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 06bcf77ce1..ac11a1e551 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -3,18 +3,15 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class SSLImagenetDataModule(LightningDataModule): # pragma: no cover From 53f2fe0ca20dfbcddb584d7e7c513fd18522e820 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:12:26 +0900 Subject: [PATCH 19/35] stl10 --- pl_bolts/datamodules/stl10_datamodule.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index b1ee3058a8..12311796f1 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -5,18 +5,16 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.datasets.concat_dataset import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE from torchvision import transforms as transform_lib from torchvision.datasets import STL10 -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False -else: - _TORCHVISION_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('torchvision') class STL10DataModule(LightningDataModule): # pragma: no cover From d7bbba2a5fe33bbd1b7080a1fc9500890491703a Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:13:53 +0900 Subject: [PATCH 20/35] cifar10 --- pl_bolts/datasets/cifar10_dataset.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 1038c7aa87..74c56a776b 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -6,17 +6,14 @@ import torch from torch import Tensor +from pl_bolts import _PIL_AVAILABLE +from pl_bolts.datasets.base_dataset import LightDataset from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _PIL_AVAILABLE: from PIL import Image -except ModuleNotFoundError: - warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover - _PIL_AVAILABLE = False -else: - _PIL_AVAILABLE = True - -from pl_bolts.datasets.base_dataset import LightDataset +else: # pragma: no-cover + warn_missing_pkg('PIL', pypi_name='Pillow') class CIFAR10(LightDataset): From 319041d89f68a4a5013a451e2fbdf047475cafca Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:14:54 +0900 Subject: [PATCH 21/35] dummy --- pl_bolts/datasets/dummy_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index 59884ceaaf..b14c503a90 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -1,5 +1,5 @@ import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset class DummyDataset(Dataset): From 451d0334805c38c81f4b9379bea19840f1ae2b25 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:19:01 +0900 Subject: [PATCH 22/35] fix city --- pl_bolts/datamodules/cityscapes_datamodule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 634ea13e6a..9b9e12b0a1 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,6 +1,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: From 48821f03777ce15fda9df5599c6396405dfaa6ec Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:19:49 +0900 Subject: [PATCH 23/35] fix stl10 --- pl_bolts/datamodules/stl10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 12311796f1..1c4e090363 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -10,7 +10,7 @@ from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 else: # pragma: no-cover From 88543a13873217c1ae1b10f9660fbe4ce4d6a62d Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:21:44 +0900 Subject: [PATCH 24/35] fix mnist --- pl_bolts/datasets/mnist_dataset.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index cdea04566b..74f41e58cc 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -1,20 +1,17 @@ +from pl_bolts import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) from err +else: # pragma: no-cover + warn_missing_pkg('torchvision') + MNIST = object -try: +if _PIL_AVAILABLE: from PIL import Image -except ModuleNotFoundError: - warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover - _PIL_AVAILABLE = False -else: - _PIL_AVAILABLE = True +else: # pragma: no-cover + warn_missing_pkg('PIL', pypi_name='Pillow') class BinaryMNIST(MNIST): From 8568398e6e98a2deb02552d6022d96ca1d650a0e Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:23:18 +0900 Subject: [PATCH 25/35] ssl_amdim --- pl_bolts/datasets/ssl_amdim_datasets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index e7cc43a78d..75660562bc 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -3,12 +3,14 @@ import numpy as np +from pl_bolts import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision.datasets import CIFAR10 -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no-cover + warn_missing_pkg('torchvision') + CIFAR10 = object class SSLDatasetMixin(ABC): From 1f074cac92812fabccdb6b578dc018107262ce26 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:41:13 +0900 Subject: [PATCH 26/35] remove unused DataLoader and fix docs --- pl_bolts/datasets/dummy_dataset.py | 35 +++++++++++++----------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index b14c503a90..910748abbf 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -6,10 +6,9 @@ class DummyDataset(Dataset): """ Generate a dummy dataset - Example:: - - from pl_bolts.datasets import DummyDataset - + Example: + >>> from pl_bolts.datasets import DummyDataset + >>> from torch.utils.data import DataLoader >>> # mnist dims >>> ds = DummyDataset((1, 28, 28), (1, )) >>> dl = DataLoader(ds, batch_size=7) @@ -46,10 +45,9 @@ class DummyDetectionDataset(Dataset): """ Generate a dummy dataset for detection - Example:: - - from pl_bolts.datasets import DummyDetectionDataset - + Example: + >>> from pl_bolts.datasets import DummyDetectionDataset + >>> from torch.utils.data import DataLoader >>> ds = DummyDetectionDataset() >>> dl = DataLoader(ds, batch_size=7) """ @@ -87,10 +85,9 @@ class RandomDictDataset(Dataset): """ Generate a dummy dataset with a dict structure - Example:: - - from pl_bolts.datasets import RandomDictDataset - + Example: + >>> from pl_bolts.datasets import RandomDictDataset + >>> from torch.utils.data import DataLoader >>> ds = RandomDictDataset(10) >>> dl = DataLoader(ds, batch_size=7) """ @@ -116,10 +113,9 @@ class RandomDictStringDataset(Dataset): """ Generate a dummy dataset with strings - Example:: - - from pl_bolts.datasets import RandomDictStringDataset - + Example: + >>> from pl_bolts.datasets import RandomDictStringDataset + >>> from torch.utils.data import DataLoader >>> ds = RandomDictStringDataset(10) >>> dl = DataLoader(ds, batch_size=7) """ @@ -143,10 +139,9 @@ class RandomDataset(Dataset): """ Generate a dummy dataset - Example:: - - from pl_bolts.datasets import RandomDataset - + Example: + >>> from pl_bolts.datasets import RandomDataset + >>> from torch.utils.data import DataLoader >>> ds = RandomDataset(10) >>> dl = DataLoader(ds, batch_size=7) """ From d8a9db26f54e9ec52d9ed754a971d16e5fc60edd Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:45:22 +0900 Subject: [PATCH 27/35] use from ... import ... --- pl_bolts/datasets/concat_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datasets/concat_dataset.py b/pl_bolts/datasets/concat_dataset.py index 48b6782c58..ae09a37c7f 100644 --- a/pl_bolts/datasets/concat_dataset.py +++ b/pl_bolts/datasets/concat_dataset.py @@ -1,7 +1,7 @@ -import torch +from torch.utils.data import Dataset -class ConcatDataset(torch.utils.data.Dataset): +class ConcatDataset(Dataset): def __init__(self, *datasets): self.datasets = datasets From 8cf8728050e71a9fda99a68ce23df8e0fe9f1fb5 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 05:55:41 +0900 Subject: [PATCH 28/35] fix pragma: no cover --- pl_bolts/datamodules/binary_mnist_datamodule.py | 8 ++++---- pl_bolts/datamodules/cifar10_datamodule.py | 6 +++--- pl_bolts/datamodules/cityscapes_datamodule.py | 6 +++--- pl_bolts/datamodules/experience_source.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 6 +++--- pl_bolts/datamodules/imagenet_datamodule.py | 8 ++++---- pl_bolts/datamodules/kitti_datamodule.py | 8 ++++---- pl_bolts/datamodules/mnist_datamodule.py | 6 +++--- pl_bolts/datamodules/sklearn_datamodule.py | 8 ++++---- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/stl10_datamodule.py | 6 +++--- pl_bolts/datamodules/vocdetection_datamodule.py | 8 ++++---- pl_bolts/datasets/cifar10_dataset.py | 2 +- pl_bolts/datasets/imagenet_dataset.py | 6 +++--- pl_bolts/datasets/kitti_dataset.py | 4 ++-- pl_bolts/datasets/mnist_dataset.py | 4 ++-- pl_bolts/datasets/ssl_amdim_datasets.py | 2 +- 17 files changed, 49 insertions(+), 49 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 0aed6b584f..ad806cc575 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -10,8 +10,8 @@ from torchvision.datasets import MNIST from pl_bolts.datasets.mnist_dataset import BinaryMNIST -else: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') class BinaryMNISTDataModule(LightningDataModule): @@ -63,8 +63,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b6745a4233..893b74c7b8 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -13,7 +13,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -80,8 +80,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 9b9e12b0a1..322bd61504 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -7,7 +7,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -80,8 +80,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use CityScapes dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 16e7a2ce7b..2a62a05a85 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -15,8 +15,8 @@ if _GYM_AVAILABLE: from gym import Env -else: - warn_missing_pkg("gym") # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg("gym") Experience = namedtuple( diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 73182d4020..f17879d1e7 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -8,7 +8,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -60,8 +60,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use fashion MNIST dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 0cfa25cdce..b67733491f 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -11,8 +11,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib -else: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') class ImagenetDataModule(LightningDataModule): @@ -69,8 +69,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 30756586b9..556eb02d22 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -11,8 +11,8 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms -else: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') class KittiDataModule(LightningDataModule): @@ -62,8 +62,8 @@ def __init__( batch_size: the batch size seed: random seed to be used for train/val/test splits """ - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index a942958048..36af8af6dc 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -8,7 +8,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -62,8 +62,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 75bd8ea656..ef4823c325 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -11,8 +11,8 @@ if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle -else: - warn_missing_pkg("sklearn") # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg("sklearn") class SklearnDataset(Dataset): @@ -158,8 +158,8 @@ def __init__( # shuffle x and y if shuffle and _SKLEARN_AVAILABLE: X, y = sk_shuffle(X, y, random_state=random_state) - elif shuffle and not _SKLEARN_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + elif shuffle and not _SKLEARN_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use shuffle function from `scikit-learn` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index ac11a1e551..ab4362be85 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -10,7 +10,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -28,8 +28,8 @@ def __init__( ): super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 1c4e090363..4844e26739 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -13,7 +13,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') @@ -73,8 +73,8 @@ def __init__( """ super().__init__(*args, **kwargs) - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use STL10 dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 1175328188..dc5c360140 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -8,8 +8,8 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as T from torchvision.datasets import VOCDetection -else: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') class Compose(object): @@ -114,8 +114,8 @@ def __init__( *args, **kwargs, ): - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 74c56a776b..1f6d5fb6b6 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -12,7 +12,7 @@ if _PIL_AVAILABLE: from PIL import Image -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index 7d9fbb51f9..db0c539ec2 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -18,7 +18,7 @@ if _TORCHVISION_AVAILABLE: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') ImageNet = object @@ -51,8 +51,8 @@ def __init__( download: kwargs: """ - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' ) diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index 0a66ccbf9e..a2650ef6e3 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -8,7 +8,7 @@ if _PIL_AVAILABLE: from PIL import Image -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('PIL') @@ -49,7 +49,7 @@ def __init__( void_labels: useless classes to be excluded from training valid_labels: useful classes to include """ - if not _PIL_AVAILABLE: # pragma: no-cover + if not _PIL_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( 'You want to use `PIL` which is not installed yet.' ) diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 74f41e58cc..806cfff2d8 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -4,13 +4,13 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') MNIST = object if _PIL_AVAILABLE: from PIL import Image -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index 75660562bc..c6ee662a4c 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -8,7 +8,7 @@ if _TORCHVISION_AVAILABLE: from torchvision.datasets import CIFAR10 -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') CIFAR10 = object From bdd1bd22d2acf597737640dd9d63bd8cd88e039c Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 06:00:36 +0900 Subject: [PATCH 29/35] Fix forward reference in annotations --- pl_bolts/datamodules/experience_source.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 2a62a05a85..503774bc75 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -17,6 +17,7 @@ from gym import Env else: # pragma: no cover warn_missing_pkg("gym") + Env = object Experience = namedtuple( @@ -179,7 +180,7 @@ def env_actions(self, device) -> List[List[int]]: return actions - def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience: + def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience: """ Carries out a step through the given environment using the given action @@ -243,7 +244,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps From 8bda71be8958ac75d1a3532dc6b86c8f333c1371 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 06:03:41 +0900 Subject: [PATCH 30/35] binmnist --- pl_bolts/datamodules/binary_mnist_datamodule.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index ad806cc575..5e6b5d0a3b 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -3,13 +3,12 @@ from torch.utils.data import DataLoader, random_split from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.datasets.mnist_dataset import BinaryMNIST from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST - - from pl_bolts.datasets.mnist_dataset import BinaryMNIST else: # pragma: no cover warn_missing_pkg('torchvision') From c2129bde630cdd9fa84820e1b01b05a5e2f16444 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Nov 2020 06:08:23 +0900 Subject: [PATCH 31/35] Same order as imports --- pl_bolts/datamodules/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index d914c4cea1..e608d71010 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -14,23 +14,23 @@ from pl_bolts.datasets.kitti_dataset import KittiDataset __all__ = [ + 'AsynchronousLoader', 'BinaryMNISTDataModule', 'CIFAR10DataModule', 'TinyCIFAR10DataModule', - 'ExperienceSourceDataset', - 'ExperienceSource', + 'CityscapesDataModule', 'DiscountedExperienceSource', + 'ExperienceSource', + 'ExperienceSourceDataset', 'FashionMNISTDataModule', 'ImagenetDataModule', + 'KittiDataModule', 'MNISTDataModule', - 'SklearnDataset', 'SklearnDataModule', + 'SklearnDataset', 'TensorDataset', 'SSLImagenetDataModule', 'STL10DataModule', 'VOCDetectionDataModule', - 'CityscapesDataModule', 'KittiDataset', - 'KittiDataModule', - 'AsynchronousLoader', ] From 81a04d06c68615a434b2d1cd3849d11e0ef87e50 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 14 Dec 2020 15:08:24 +0900 Subject: [PATCH 32/35] Move vars from __init__ to utils/__init__ --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 2 +- pl_bolts/datamodules/experience_source.py | 2 +- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- pl_bolts/datasets/cifar10_dataset.py | 2 +- pl_bolts/datasets/imagenet_dataset.py | 2 +- pl_bolts/datasets/kitti_dataset.py | 2 +- pl_bolts/datasets/mnist_dataset.py | 2 +- pl_bolts/datasets/ssl_amdim_datasets.py | 2 +- pl_bolts/utils/__init__.py | 10 ++++++++++ 18 files changed, 27 insertions(+), 17 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 6d2562ad01..177a6406c2 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -2,7 +2,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.mnist_dataset import BinaryMNIST from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 6ff8c2e596..8da54f3cd2 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -5,7 +5,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 6df548d25a..6445d52f7f 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,7 +1,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 503774bc75..8d9170d8bc 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import IterableDataset -from pl_bolts import _GYM_AVAILABLE +from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index e5cc07a882..4f6c1bd4cb 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -2,7 +2,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index b67733491f..a03ee2c9fb 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -4,7 +4,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 1056b1cda8..c57b1b06e6 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.kitti_dataset import KittiDataset from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 09816b04c6..f753561345 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -2,7 +2,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index a2184be81d..8ef4bc47e1 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset -from pl_bolts import _SKLEARN_AVAILABLE +from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index b2eaa891ed..63360d8306 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -3,7 +3,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 81f16ef543..6e93f2750d 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -5,7 +5,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.concat_dataset import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index dc5c360140..4235316cd2 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -2,7 +2,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 1f6d5fb6b6..cdcf34bbcc 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from pl_bolts import _PIL_AVAILABLE +from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.datasets.base_dataset import LightDataset from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index db0c539ec2..e85f1d4b83 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -12,7 +12,7 @@ import torch from torch._six import PY3 -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index a2650ef6e3..80a230e7d2 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -3,7 +3,7 @@ import numpy as np from torch.utils.data import Dataset -from pl_bolts import _PIL_AVAILABLE +from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 806cfff2d8..e48722aa7f 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -1,4 +1,4 @@ -from pl_bolts import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index c6ee662a4c..0fe935be07 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -3,7 +3,7 @@ import numpy as np -from pl_bolts import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index e69de29bb2..7bfaeb94d3 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -0,0 +1,10 @@ +import torch +from pytorch_lightning.utilities import _module_available + +_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") + +_TORCHVISION_AVAILABLE = _module_available("torchvision") +_GYM_AVAILABLE = _module_available("gym") +_SKLEARN_AVAILABLE = _module_available("sklearn") +_PIL_AVAILABLE = _module_available("PIL") +_OPENCV_AVAILABLE = _module_available("cv2") From 43e0abcf93b222f26f56b96106eae75d6aed4753 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 14 Dec 2020 15:10:43 +0900 Subject: [PATCH 33/35] Remove vars from __init__ --- pl_bolts/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index a8e52841ee..82a4f14063 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -1,7 +1,6 @@ """Root package info.""" import os -from importlib.util import find_spec __version__ = '0.2.5rc1' __author__ = 'PyTorchLightning et al.' @@ -31,12 +30,6 @@ PACKAGE_ROOT = os.path.dirname(__file__) -_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None -_SKLEARN_AVAILABLE = find_spec("sklearn") is not None -_PIL_AVAILABLE = find_spec("PIL") is not None -_GYM_AVAILABLE = find_spec("gym") is not None -_OPENCV_AVAILABLE = find_spec("cv2") is not None - try: # This variable is injected in the __builtins__ by the build process. # It used to enable importing subpackages when the binaries are not built. From 280ed47c022e4673fe62439e6405a9bd0448b74a Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 14 Dec 2020 15:20:30 +0900 Subject: [PATCH 34/35] Update vars --- pl_bolts/utils/__init__.py | 12 +++++++----- pl_bolts/utils/semi_supervised.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index 7bfaeb94d3..d05bac0028 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -1,10 +1,12 @@ +from importlib.util import find_spec + import torch from pytorch_lightning.utilities import _module_available _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") -_TORCHVISION_AVAILABLE = _module_available("torchvision") -_GYM_AVAILABLE = _module_available("gym") -_SKLEARN_AVAILABLE = _module_available("sklearn") -_PIL_AVAILABLE = _module_available("PIL") -_OPENCV_AVAILABLE = _module_available("cv2") +_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None +_GYM_AVAILABLE = find_spec("gym") is not None +_SKLEARN_AVAILABLE = find_spec("sklearn") is not None +_PIL_AVAILABLE = find_spec("PIL") is not None +_OPENCV_AVAILABLE = find_spec("cv2") is not None diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 82f9343b01..8363fa44b5 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -3,7 +3,7 @@ import numpy as np import torch -from pl_bolts import _SKLEARN_AVAILABLE +from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: From fc579eac701c33f91886f286bec68152bac265ae Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 14 Dec 2020 15:30:23 +0900 Subject: [PATCH 35/35] Apply isort --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datasets/cifar10_dataset.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 177a6406c2..97460dfe5f 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -2,8 +2,8 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.mnist_dataset import BinaryMNIST +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 8da54f3cd2..8976bf41db 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -5,9 +5,9 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index a03ee2c9fb..7984f7c66e 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -4,9 +4,9 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index c57b1b06e6..aa4895a9b0 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -5,8 +5,8 @@ from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.kitti_dataset import KittiDataset +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 63360d8306..0781677321 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -3,9 +3,9 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 6e93f2750d..98577f2f4a 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -5,9 +5,9 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.datasets.concat_dataset import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index cdcf34bbcc..8ea84fc9cc 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -6,8 +6,8 @@ import torch from torch import Tensor -from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.datasets.base_dataset import LightDataset +from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: