Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Add missing names to pl_bolts/datasets/__init__.py #493

Merged
merged 12 commits into from
Jan 19, 2021
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down Expand Up @@ -136,7 +136,7 @@ def prepare_data(self):

To generate the meta.bin do the following:

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
path = '/path/to/folder/with/ILSVRC2012_devkit_t12.tar.gz/'
UnlabeledImagenet.generate_meta_bins(path)
"""
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets import KittiDataset
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets.concat_dataset import ConcatDataset
from pl_bolts.datasets import ConcatDataset
from pl_bolts.transforms.dataset_normalizations import stl10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
24 changes: 21 additions & 3 deletions pl_bolts/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10
from pl_bolts.datasets.concat_dataset import ConcatDataset
from pl_bolts.datasets.dummy_dataset import (
DummyDataset,
DummyDetectionDataset,
RandomDataset,
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet
from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin

__all__ = [
"RandomDictStringDataset",
"RandomDictDataset",
"RandomDataset",
"LightDataset",
"CIFAR10",
"TrialCIFAR10",
"ConcatDataset",
"DummyDataset",
"DummyDetectionDataset",
"RandomDataset",
"RandomDictDataset",
"RandomDictStringDataset",
"extract_archive",
"parse_devkit_archive",
"UnlabeledImagenet",
"KittiDataset",
"BinaryMNIST",
"CIFAR10Mixed",
"SSLDatasetMixin",
]
2 changes: 1 addition & 1 deletion pl_bolts/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import Tensor

from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.datasets import LightDataset
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/models/self_supervised/amdim/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from torch.utils.data import random_split

from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed
from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_byol(tmpdir, datadir):
def test_amdim(tmpdir, datadir):
seed_everything()

model = AMDIM(data_dir=datadir, batch_size=2, online_ft=True, encoder='resnet18')
model = AMDIM(data_dir=datadir, batch_size=2, online_ft=True, encoder='resnet18', num_workers=0)
trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
loss = trainer.progress_bar_dict['loss']
Expand Down