Adding types to some of datamodules (#462)
* Adding types to datamodules

* Fixing typing imports

* Removing torchvision.transforms from return typing

* Remove more torchvision.transforms typing

* Removing return typing

* Add `None` for optional arguments

* Remove unnecessary import

* Add `None` return type

* Add type for torchvision transforms

* enable check

* Fix rebasing mistakes

* Fix flake8

* Fix yapf format

* Add types and skip mypy checks on some files

* Fix setup.cfg

* Add missing import

* isort

* yapf

* mypy please...

* Please be quiet mypy and flake8

* yapf...

* Disable all of yapf, flake8, and mypy

* Use Callable

* Use Callable

* Add missing import

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
4 people authored Jan 20, 2021
1 parent c38c579 commit f0cc60b
Showing 17 changed files with 168 additions and 130 deletions.
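
The diffs below repeat a small set of typing moves: annotate `*args: Any` / `**kwargs: Any`, give `__init__` a `None` return type, type transform factories as `Callable`, and quiet mypy with targeted `# type: ignore[...]` comments where Lightning or torchvision types don't line up. A minimal sketch of the pattern (this class is a hypothetical stand-in, not one of the changed files):

    from typing import Any, Callable, Optional

    class ExampleDataModule:

        def __init__(
            self,
            data_dir: Optional[str] = None,
            num_workers: int = 16,
            *args: Any,
            **kwargs: Any,
        ) -> None:
            self.data_dir = data_dir
            self.num_workers = num_workers

        def default_transforms(self) -> Callable:
            # Stand-in for a torchvision Compose; returns the identity transform.
            return lambda x: x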
5 changes: 3 additions & 2 deletions pl_bolts/callbacks/byol_updates.py
@@ -66,7 +66,8 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
         # apply MA weight update
         for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(), target_net.named_parameters()
-        ):  # type: ignore[union-attr]
+            online_net.named_parameters(),  # type: ignore[union-attr]
+            target_net.named_parameters()  # type: ignore[union-attr]
+        ):
             if 'weight' in name:
                 target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
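
For reference, a runnable sketch of the moving-average update this callback applies, using two throwaway `Linear` modules (the networks and tau value here are illustrative, not from the diff):

    import torch
    from torch.nn import Linear

    online_net, target_net = Linear(4, 4), Linear(4, 4)
    current_tau = 0.996  # illustrative decay rate

    with torch.no_grad():
        for (name, online_p), (_, target_p) in zip(
            online_net.named_parameters(),
            target_net.named_parameters()
        ):
            if 'weight' in name:
                target_p.data = current_tau * target_p.data + (1 - current_tau) * online_p.data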
5 changes: 3 additions & 2 deletions pl_bolts/callbacks/variational.py
@@ -62,8 +62,9 @@ def __init__(
     def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
             images = self.interpolate_latent_space(
-                pl_module, latent_dim=pl_module.hparams.latent_dim
-            )  # type: ignore[union-attr]
+                pl_module,
+                latent_dim=pl_module.hparams.latent_dim  # type: ignore[union-attr]
+            )
             images = torch.cat(images, dim=0)  # type: ignore[assignment]
 
             num_images = (self.range_end - self.range_start)**2
29 changes: 20 additions & 9 deletions pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union
 
 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
         constructing one here
     """
 
-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+        self,
+        data: Union[DataLoader, Dataset],
+        device: torch.device = torch.device('cuda', 0),
+        q_size: int = 10,
+        num_batches: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -43,20 +51,20 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=
         self.q_size = q_size
 
         self.load_stream = torch.cuda.Stream(device=device)
-        self.queue = Queue(maxsize=self.q_size)
+        self.queue: Queue = Queue(maxsize=self.q_size)
 
         self.idx = 0
 
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
-    def load_loop(self):  # The loop that will load into the queue in the background
+    def load_loop(self) -> None:  # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
                 break
 
     # Recursive loading for each instance based on torch.utils.data.default_collate
-    def load_instance(self, sample):
+    def load_instance(self, sample: Any) -> Any:
         elem_type = type(sample)
 
         if torch.is_tensor(sample):
@@ -80,16 +88,19 @@ def load_instance(self, sample):
         else:
             return sample
 
-    def __iter__(self):
+    def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:
+
+        # yapf: disable
+        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
             self.worker = Thread(target=self.load_loop)
+            # yapf: enable
             self.worker.daemon = True
             self.worker.start()
         return self
 
-    def __next__(self):
+    def __next__(self) -> torch.Tensor:
         # If we've reached the number of batches to return
         # or the queue is empty and the worker is dead then exit
         done = not self.worker.is_alive() and self.queue.empty()
@@ -105,5 +116,5 @@ def __next__(self):
         self.idx += 1
         return out
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
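
With the new annotations, `data` accepts either a `DataLoader` or a bare `Dataset`, in which case a `DataLoader` is built from `**kwargs`. A usage sketch, assuming a CUDA device is available; the dataset and batch size are illustrative:

    import torch
    from torch.utils.data import TensorDataset

    from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

    dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64, )))
    if torch.cuda.is_available():
        loader = AsynchronousLoader(dataset, device=torch.device('cuda', 0), q_size=4, batch_size=8)
        for batch in loader:
            pass  # each batch has already been moved to the GPU by the background thread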
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets import BinaryMNIST
@@ -76,7 +76,7 @@ def __init__(
                 "You want to use transforms loaded from `torchvision` which is not installed yet."
             )
 
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
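
`default_transforms` is now typed as returning a `Callable`; the concrete value is a torchvision `Compose`. A standalone sketch of what the normalized branch builds, assuming torchvision is installed:

    from torchvision import transforms as transform_lib

    mnist_transforms = transform_lib.Compose([
        transform_lib.ToTensor(),
        transform_lib.Normalize(mean=(0.5, ), std=(0.5, )),
    ])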
10 changes: 5 additions & 5 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets import TrialCIFAR10
@@ -85,7 +85,7 @@ def __init__(
                 returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
@@ -164,7 +164,7 @@ def __init__(
         """
         super().__init__(data_dir, val_split, num_workers, *args, **kwargs)
 
-        self.num_samples = num_samples
+        self.num_samples = num_samples  # type: ignore[misc]
         self.labels = sorted(labels) if labels is not None else set(range(10))
         self.extra_args = dict(num_samples=self.num_samples, labels=self.labels)
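
`data_dir` on `TinyCIFAR10DataModule` is now `Optional[str] = None`, matching the parent class. A hedged usage sketch (assumes pl_bolts is installed and the dataset can be downloaded to the default location):

    from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule

    dm = CIFAR10DataModule(data_dir=None, num_workers=0)
    dm.prepare_data()
    dm.setup()
    train_loader = dm.train_dataloader()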
23 changes: 13 additions & 10 deletions pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,6 @@
+# type: ignore[override]
+from typing import Any, Callable
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
@@ -56,7 +59,7 @@ class CityscapesDataModule(LightningDataModule):
     """
 
     name = 'Cityscapes'
-    extra_args = {}
+    extra_args: dict = {}
 
     def __init__(
         self,
@@ -69,9 +72,9 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
@@ -109,14 +112,14 @@ def __init__(
         self.target_transforms = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -143,7 +146,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -170,7 +173,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -196,7 +199,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Callable:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -205,7 +208,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Callable:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
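
Both transform factories above are now typed as `Callable`. A standalone equivalent of the target transform, assuming torchvision is installed:

    from torchvision import transforms as transform_lib

    target_transforms = transform_lib.Compose([
        transform_lib.ToTensor(),
        transform_lib.Lambda(lambda t: t.squeeze()),
    ])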
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/experience_source.py
@@ -27,7 +27,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """
 
-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch
 
     def __iter__(self) -> Iterable:
@@ -240,7 +240,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""
 
-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
@@ -299,5 +299,5 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
         """
         total_reward = 0.0
         for exp in reversed(experiences):
-            total_reward = (self.gamma * total_reward) + exp.reward
+            total_reward = (self.gamma * total_reward) + exp.reward  # type: ignore[attr-defined]
         return total_reward
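
A worked sketch of the discounted-reward fold in `discount_rewards`, with plain floats standing in for `Experience` objects (the reward values are illustrative):

    gamma = 0.99
    rewards = [1.0, 0.0, 2.0]  # oldest step first

    total_reward = 0.0
    for reward in reversed(rewards):
        total_reward = (gamma * total_reward) + reward

    # Folds newest-to-oldest: 1.0 + 0.99 * (0.0 + 0.99 * 2.0) = 2.9602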
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
@@ -76,7 +76,7 @@ def __init__(
                 'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.'
            )
 
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
(Diffs for the remaining 9 of the 17 changed files are not shown here.)
