diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 8f47815521..1c4d3ba7d4 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -66,7 +66,8 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float: def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None: # apply MA weight update for (name, online_p), (_, target_p) in zip( - online_net.named_parameters(), target_net.named_parameters() - ): # type: ignore[union-attr] + online_net.named_parameters(), # type: ignore[union-attr] + target_net.named_parameters() # type: ignore[union-attr] + ): if 'weight' in name: target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 5947f40be8..4d5f4c6e23 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -62,8 +62,9 @@ def __init__( def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: images = self.interpolate_latent_space( - pl_module, latent_dim=pl_module.hparams.latent_dim - ) # type: ignore[union-attr] + pl_module, + latent_dim=pl_module.hparams.latent_dim # type: ignore[union-attr] + ) images = torch.cat(images, dim=0) # type: ignore[assignment] num_images = (self.range_end - self.range_start)**2 diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 7ded9d9ef1..24fa820d67 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,11 @@ import re from queue import Queue from threading import Thread +from typing import Any, Optional, Union import torch from torch._six import container_abcs, string_classes -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset class AsynchronousLoader(object): @@ -26,7 +27,14 @@ class AsynchronousLoader(object): constructing one here """ - def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + def __init__( + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any, + ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -43,20 +51,20 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches= self.q_size = q_size self.load_stream = torch.cuda.Stream(device=device) - self.queue = Queue(maxsize=self.q_size) + self.queue: Queue = Queue(maxsize=self.q_size) self.idx = 0 self.np_str_obj_array_pattern = re.compile(r'[SaUO]') - def load_loop(self): # The loop that will load into the queue in the background + def load_loop(self) -> None: # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) if i == len(self): break # Recursive loading for each instance based on torch.utils.data.default_collate - def load_instance(self, sample): + def load_instance(self, sample: Any) -> Any: elem_type = type(sample) if torch.is_tensor(sample): @@ -80,16 +88,19 @@ def load_instance(self, sample): else: return sample - def __iter__(self): + def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: + + # yapf: disable + if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 self.worker = Thread(target=self.load_loop) + # yapf: enable self.worker.daemon = True self.worker.start() return self - def __next__(self): + def __next__(self) -> torch.Tensor: # If we've reached the number of batches to return # or the queue is empty and the worker is dead then exit done = not self.worker.is_alive() and self.queue.empty() @@ -105,5 +116,5 @@ def __next__(self): self.idx += 1 return out - def __len__(self): + def __len__(self) -> int: return self.num_batches diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 5202c2c261..c43065984b 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets.mnist_dataset import BinaryMNIST @@ -76,7 +76,7 @@ def __init__( "You want to use transforms loaded from `torchvision` which is not installed yet." ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cb4dcb0944..e54eb37deb 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 @@ -85,7 +85,7 @@ def __init__( returning them drop_last: If true drops the last incomplete batch """ - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Callable: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str] = None, val_split: int = 50, num_workers: int = 16, num_samples: int = 100, @@ -164,7 +164,7 @@ def __init__( """ super().__init__(data_dir, val_split, num_workers, *args, **kwargs) - self.num_samples = num_samples + self.num_samples = num_samples # type: ignore[misc] self.labels = sorted(labels) if labels is not None else set(range(10)) self.extra_args = dict(num_samples=self.num_samples, labels=self.labels) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 2e6889038e..3f1c223baf 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,6 @@ +# type: ignore[override] +from typing import Any, Callable + from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -56,7 +59,7 @@ class CityscapesDataModule(LightningDataModule): """ name = 'Cityscapes' - extra_args = {} + extra_args: dict = {} def __init__( self, @@ -69,9 +72,9 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: """ Args: data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse @@ -109,14 +112,14 @@ def __init__( self.target_transforms = None @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 30 """ return 30 - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Cityscapes train set """ @@ -143,7 +146,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Cityscapes val set """ @@ -170,7 +173,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Cityscapes test set """ @@ -196,7 +199,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Callable: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -205,7 +208,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> Callable: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) ]) diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 6c85f76fd2..50ed2a6a7b 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -27,7 +27,7 @@ class ExperienceSourceDataset(IterableDataset): The logic for the experience source and how the batch is generated is defined the Lightning model itself """ - def __init__(self, generate_batch: Callable): + def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch def __iter__(self) -> Iterable: @@ -240,7 +240,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None: super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps @@ -299,5 +299,5 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float: """ total_reward = 0.0 for exp in reversed(experiences): - total_reward = (self.gamma * total_reward) + exp.reward + total_reward = (self.gamma * total_reward) + exp.reward # type: ignore[attr-defined] return total_reward diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 2de4b31acf..f945e00912 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -76,7 +76,7 @@ def __init__( 'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.' ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 6432d5399a..b9cd811335 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,5 +1,6 @@ +# type: ignore[override] import os -from typing import Optional +from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -58,9 +59,9 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: """ Args: data_dir: path to the imagenet dataset file @@ -94,7 +95,7 @@ def __init__( self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property - def num_classes(self): + def num_classes(self) -> int: """ Return: @@ -103,7 +104,7 @@ def num_classes(self): """ return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: @@ -112,7 +113,7 @@ def _verify_splits(self, data_dir, split): f' make sure the folder contains a subfolder named {split}' ) - def prepare_data(self): + def prepare_data(self) -> None: """ This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. @@ -142,7 +143,7 @@ def prepare_data(self): """ ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Uses the train split of imagenet2012 and puts away a portion of it for the validation split """ @@ -156,7 +157,7 @@ def train_dataloader(self): split='train', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=self.shuffle, @@ -166,7 +167,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` @@ -183,7 +184,7 @@ def val_dataloader(self): split='val', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -193,7 +194,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Uses the validation split of imagenet2012 for testing """ @@ -202,7 +203,7 @@ def test_dataloader(self): dataset = UnlabeledImagenet( self.data_dir, num_imgs_per_class=-1, meta_dir=self.meta_dir, split='test', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -212,7 +213,7 @@ def test_dataloader(self): ) return loader - def train_transform(self): + def train_transform(self) -> Callable: """ The standard imagenet transforms @@ -238,7 +239,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> Callable: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 12012a5477..cd6d198185 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,4 +1,6 @@ +# type: ignore[override] import os +from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule @@ -21,7 +23,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, @@ -30,9 +32,9 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: """ Kitti train, validation and test dataloaders. @@ -92,7 +94,7 @@ def __init__( kitti_dataset, lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed) ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.trainset, batch_size=self.batch_size, @@ -103,7 +105,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.valset, batch_size=self.batch_size, @@ -114,7 +116,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.testset, batch_size=self.batch_size, @@ -125,7 +127,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Callable: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1e4053770d..0889d71d09 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -75,7 +75,7 @@ def __init__( 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, @@ -97,7 +97,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 79d7d8aeb3..45ce688192 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,5 +1,5 @@ import math -from typing import Any +from typing import Any, Tuple import numpy as np import torch @@ -29,7 +29,7 @@ class SklearnDataset(Dataset): 506 """ - def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): + def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None) -> None: """ Args: X: Numpy ndarray @@ -43,10 +43,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: x = self.X[idx].astype(np.float32) y = self.Y[idx] @@ -77,7 +77,7 @@ class TensorDataset(Dataset): 10 """ - def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): + def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None) -> None: """ Args: X: PyTorch tensor @@ -91,10 +91,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_ self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: x = self.X[idx].float() y = self.Y[idx] @@ -160,7 +160,7 @@ def __init__( drop_last=False, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.num_workers = num_workers @@ -200,12 +200,14 @@ def __init__( self._init_datasets(X, y, x_val, y_val, x_test, y_test) - def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): + def _init_datasets( + self, X: np.ndarray, y: np.ndarray, x_val: np.ndarray, y_val: np.ndarray, x_test: np.ndarray, y_test: np.ndarray + ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -216,7 +218,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -227,7 +229,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 14cafc73e1..fc14dd2cae 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,4 +1,6 @@ +# type: ignore[override] import os +from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,16 +22,16 @@ class SSLImagenetDataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir, - meta_dir=None, - num_workers=16, + data_dir: str, + meta_dir: Optional[str] = None, + num_workers: int = 16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: # pragma: no cover @@ -46,10 +48,10 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: @@ -58,7 +60,7 @@ def _verify_splits(self, data_dir, split): f' folder contains a subfolder named {split}' ) - def prepare_data(self): + def prepare_data(self) -> None: # imagenet cannot be downloaded... must provide path to folder with the train/val splits self._verify_splits(self.data_dir, 'train') self._verify_splits(self.data_dir, 'val') @@ -83,7 +85,7 @@ def prepare_data(self): """ ) - def train_dataloader(self, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet( @@ -93,7 +95,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): split='train', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=self.shuffle, @@ -103,7 +105,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): ) return loader - def val_dataloader(self, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet( @@ -113,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): split='val', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -123,7 +125,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): ) return loader - def test_dataloader(self, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet( @@ -133,7 +135,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False): split='test', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -143,6 +145,6 @@ def test_dataloader(self, num_images_per_class, add_normalize=False): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Callable: mnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()]) return mnist_transforms diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 50554c75e6..43ff3ebb6a 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,5 +1,6 @@ +# type: ignore[override] import os -from typing import Optional +from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule @@ -63,9 +64,9 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: """ Args: data_dir: where to save/load the data @@ -99,10 +100,10 @@ def __init__( self.num_unlabeled_samples = 100000 - unlabeled_val_split @property - def num_classes(self): + def num_classes(self) -> int: return 10 - def prepare_data(self): + def prepare_data(self) -> None: """ Downloads the unlabeled, train and test split """ @@ -110,7 +111,7 @@ def prepare_data(self): STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor()) STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ @@ -132,7 +133,7 @@ def train_dataloader(self): ) return loader - def train_dataloader_mixed(self): + def train_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data and 'train' (labeled) data. both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split` @@ -169,7 +170,7 @@ def train_dataloader_mixed(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation The val dataset = (unlabeled - train_val_split) @@ -192,12 +193,12 @@ def val_dataloader(self): batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - drpo_last=self.drop_last, + drop_last=self.drop_last, pin_memory=self.pin_memory ) return loader - def val_dataloader_mixed(self): + def val_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation along with the portion of the 'train' dataset to be used for validation @@ -239,7 +240,7 @@ def val_dataloader_mixed(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Loads the test split of STL10 @@ -260,7 +261,7 @@ def test_dataloader(self): ) return loader - def train_dataloader_labeled(self): + def train_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) @@ -279,7 +280,7 @@ def train_dataloader_labeled(self): ) return loader - def val_dataloader_labeled(self): + def val_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) labeled_length = len(dataset) @@ -298,6 +299,6 @@ def val_dataloader_labeled(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Callable: data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()]) return data_transforms diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5a6f4af4c2..d15d9c59d4 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -1,6 +1,6 @@ import os from abc import abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from pytorch_lightning import LightningDataModule @@ -9,12 +9,12 @@ class VisionDataModule(LightningDataModule): - EXTRA_ARGS = {} + EXTRA_ARGS: dict = {} name: str = "" #: Dataset class to use - dataset_cls = ... + dataset_cls: type #: A tuple describing the shape of the data - dims: tuple = ... + dims: tuple def __init__( self, @@ -56,7 +56,7 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ Saves files to data_dir """ @@ -88,7 +88,7 @@ def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset: """ Splits the dataset into train and validation set """ - len_dataset = len(dataset) + len_dataset = len(dataset) # type: ignore[arg-type] splits = self._get_splits(len_dataset) dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed)) @@ -113,18 +113,18 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> Callable: """ Default transform for the dataset """ - def train_dataloader(self) -> DataLoader: + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: """ The train dataloader """ return self._data_loader(self.dataset_train, shuffle=self.shuffle) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The val dataloader """ return self._data_loader(self.dataset_val) - def test_dataloader(self) -> DataLoader: + def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The test dataloader """ return self._data_loader(self.dataset_test) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index d7a18fd3e8..97b63cc86e 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -17,11 +19,11 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms, image_transforms=None): + def __init__(self, transforms: List[Callable], image_transforms: Optional[Callable] = None) -> None: self.transforms = transforms self.image_transforms = image_transforms - def __call__(self, image, target): + def __call__(self, image: Any, target: Any) -> Tuple[torch.Tensor, torch.Tensor]: for t in self.transforms: image, target = t(image, target) if self.image_transforms: @@ -29,7 +31,7 @@ def __call__(self, image, target): return image, target -def _collate_fn(batch): +def _collate_fn(batch: List[torch.Tensor]) -> tuple: return tuple(zip(*batch)) @@ -58,7 +60,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target): +def _prepare_voc_instance(image: Any, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -116,9 +118,9 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> None: if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' @@ -135,21 +137,23 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 21 """ return 21 - def prepare_data(self): + def prepare_data(self) -> None: """ Saves VOCDetection files to data_dir """ VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size=1, image_transforms=None): + def train_dataloader( + self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None + ) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -172,7 +176,7 @@ def train_dataloader(self, batch_size=1, image_transforms=None): ) return loader - def val_dataloader(self, batch_size=1, image_transforms=None): + def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -195,7 +199,7 @@ def val_dataloader(self, batch_size=1, image_transforms=None): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Callable: if self.normalize: voc_transforms = transform_lib.Compose([ transform_lib.ToTensor(), diff --git a/setup.cfg b/setup.cfg index 080004f375..5883253ce5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,10 +72,20 @@ show_error_codes = True disallow_untyped_defs = True ignore_missing_imports = True -[mypy-pl_bolts.datamodules.*] +[mypy-pl_bolts.datasets.*] ignore_errors = True -[mypy-pl_bolts.datasets.*] +[mypy-pl_bolts.datamodules] +# pl_bolts/datamodules/__init__.py +ignore_errors = True + +[mypy-pl_bolts.datamodules.experience_source] +ignore_errors = True + +[mypy-pl_bolts.datamodules.sklearn_datamodule] +ignore_errors = True + +[mypy-pl_bolts.datamodules.vocdetection_datamodule] ignore_errors = True [mypy-pl_bolts.losses.*]