Adding types to some of datamodules #462

Merged: 81 commits, merged on Jan 20, 2021
Commits (81)
34f5f07
Adding types to datamodules
briankosw Dec 18, 2020
2b55b32
Fixing typing imports
briankosw Dec 18, 2020
ac3377d
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
a4c39c7
Remove more torchvision.transforms typing
briankosw Dec 19, 2020
ffa0cb9
Removing return typing
briankosw Dec 21, 2020
5e6c5d4
Add `None` for optional arguments
briankosw Dec 21, 2020
3a5a0ab
Remove unnecessary import
briankosw Dec 21, 2020
30579ed
Remove unnecessary import
briankosw Dec 21, 2020
5de590c
Merge branch 'types/datamodules' of https://github.com/briankosw/pyto…
briankosw Dec 21, 2020
c675931
Add `None` return type
briankosw Dec 21, 2020
267649c
Add type for torchvision transforms
briankosw Jan 5, 2021
3938fff
Adding types to datamodules
briankosw Dec 18, 2020
cae3f46
Fixing typing imports
briankosw Dec 18, 2020
7af027e
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
7bec605
Remove more torchvision.transforms typing
briankosw Dec 19, 2020
afbc918
Removing return typing
briankosw Dec 21, 2020
17ce335
Add `None` for optional arguments
briankosw Dec 21, 2020
d09f98d
Remove unnecessary import
briankosw Dec 21, 2020
b61fdc0
Add `None` return type
briankosw Dec 21, 2020
f2f4305
Add type for torchvision transforms
briankosw Jan 5, 2021
cd09554
enable check
Borda Jan 5, 2021
797a9d1
Merge branch 'types/datamodules' of https://github.com/briankosw/pyto…
briankosw Jan 6, 2021
0fcd186
Adding types to datamodules
briankosw Dec 18, 2020
a430696
Fixing typing imports
briankosw Dec 18, 2020
a84551e
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
685162c
Remove more torchvision.transforms typing
briankosw Dec 19, 2020
0140837
Removing return typing
briankosw Dec 21, 2020
a6b8d4a
Add `None` for optional arguments
briankosw Dec 21, 2020
de35a55
Remove unnecessary import
briankosw Dec 21, 2020
f521b79
Add `None` return type
briankosw Dec 21, 2020
fa0d271
Add type for torchvision transforms
briankosw Jan 5, 2021
0bc9f7b
Adding types to datamodules
briankosw Dec 18, 2020
05fcef2
Fixing typing imports
briankosw Dec 18, 2020
d92604b
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
a9c641b
Removing return typing
briankosw Dec 21, 2020
314329c
Add `None` return type
briankosw Dec 21, 2020
14ea6b7
enable check
Borda Jan 5, 2021
e97332c
Merge branch 'types/datamodules' of https://github.com/briankosw/pyto…
briankosw Jan 12, 2021
8a7c6f1
Adding types to datamodules
briankosw Dec 18, 2020
3b0ee3c
Fixing typing imports
briankosw Dec 18, 2020
3d1c9a1
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
6cac909
Remove more torchvision.transforms typing
briankosw Dec 19, 2020
c1ea0fb
Removing return typing
briankosw Dec 21, 2020
c04caab
Add `None` for optional arguments
briankosw Dec 21, 2020
5b6bf64
Remove unnecessary import
briankosw Dec 21, 2020
7ce736d
Add `None` return type
briankosw Dec 21, 2020
7309ade
Add type for torchvision transforms
briankosw Jan 5, 2021
cc154a7
Adding types to datamodules
briankosw Dec 18, 2020
cff7330
Fixing typing imports
briankosw Dec 18, 2020
3bbc189
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
5b2401b
Removing return typing
briankosw Dec 21, 2020
7eb32b8
Add `None` return type
briankosw Dec 21, 2020
47cca32
enable check
Borda Jan 5, 2021
52e4811
Adding types to datamodules
briankosw Dec 18, 2020
64f871e
Fixing typing imports
briankosw Dec 18, 2020
bf6ee1c
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
b5c6dcc
Removing return typing
briankosw Dec 21, 2020
984a962
Add `None` return type
briankosw Dec 21, 2020
7894471
Adding types to datamodules
briankosw Dec 18, 2020
3443883
Fixing typing imports
briankosw Dec 18, 2020
51f8f16
Removing torchvision.transforms from return typing
briankosw Dec 19, 2020
3062dba
Removing return typing
briankosw Dec 21, 2020
53ebe33
Add `None` return type
briankosw Dec 21, 2020
ec068e4
Merge branch 'types/datamodules' of https://github.com/briankosw/pyto…
briankosw Jan 12, 2021
c15efdb
Fix rebasing mistakes
briankosw Jan 12, 2021
7bc0c37
Fix flake8
briankosw Jan 12, 2021
a5f3e4f
Fix yapf format
briankosw Jan 12, 2021
e0c05a2
Merge branch 'master' into types/datamodules
Borda Jan 18, 2021
bd97183
Merge branch 'master' into briankosw-types/datamodules
akihironitta Jan 19, 2021
b9c910d
Add types and skip mypy checks on some files
akihironitta Jan 20, 2021
9e222d0
Fix setup.cfg
akihironitta Jan 20, 2021
0c54fdd
Add missing import
akihironitta Jan 20, 2021
8b2e196
isort
akihironitta Jan 20, 2021
9c5dd5c
yapf
akihironitta Jan 20, 2021
c6c97e1
mypy please...
akihironitta Jan 20, 2021
4ac8a5b
Please be quiet mypy and flake8
akihironitta Jan 20, 2021
e847ead
yapf...
akihironitta Jan 20, 2021
1839438
Disable all of yapf, flake8, and mypy
akihironitta Jan 20, 2021
097df6d
Use Callable
akihironitta Jan 20, 2021
9e00c5d
Use Callable
akihironitta Jan 20, 2021
af563a5
Add missing import
akihironitta Jan 20, 2021
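
Several of the commits above ('enable check', 'Add types and skip mypy checks on some files', 'Fix setup.cfg') gate mypy per module, so only the newly typed files are enforced. A minimal sketch of that kind of setup.cfg gating, assuming mypy's standard per-module sections (the module names are illustrative, not the PR's actual list):

    ; setup.cfg (sketch) -- check typed modules, skip the rest for now
    [mypy]
    ignore_missing_imports = True

    ; modules whose annotations are complete are checked
    [mypy-pl_bolts.datamodules.*]
    ignore_errors = False

    ; untyped modules are skipped until they gain annotations
    [mypy-pl_bolts.models.*]
    ignore_errors = True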
5 changes: 3 additions & 2 deletions pl_bolts/callbacks/byol_updates.py
@@ -66,7 +66,8 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
         # apply MA weight update
         for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(), target_net.named_parameters()
-        ): # type: ignore[union-attr]
+            online_net.named_parameters(), # type: ignore[union-attr]
+            target_net.named_parameters() # type: ignore[union-attr]
+        ):
             if 'weight' in name:
                 target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
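
The ignores above are needed because `online_net` and `target_net` are typed `Union[Module, Tensor]`, and `Tensor` has no `named_parameters()` method; since a `# type: ignore` comment silences only its own line, the call is reflowed so each flagged expression carries its own ignore. A minimal sketch of the error being silenced (illustrative, not code from this PR):

    from typing import Union

    from torch import Tensor
    from torch.nn import Module

    def first_param_name(net: Union[Module, Tensor]) -> str:
        # without the ignore, mypy reports:
        #   Item "Tensor" of "Union[Module, Tensor]" has no attribute "named_parameters" [union-attr]
        return next(iter(net.named_parameters()))[0]  # type: ignore[union-attr]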
5 changes: 3 additions & 2 deletions pl_bolts/callbacks/variational.py
@@ -62,8 +62,9 @@ def __init__(
     def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
             images = self.interpolate_latent_space(
-                pl_module, latent_dim=pl_module.hparams.latent_dim
-            ) # type: ignore[union-attr]
+                pl_module,
+                latent_dim=pl_module.hparams.latent_dim # type: ignore[union-attr]
+            )
             images = torch.cat(images, dim=0) # type: ignore[assignment]

             num_images = (self.range_end - self.range_start)**2
29 changes: 20 additions & 9 deletions pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union

 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset


 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
         constructing one here
     """

-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+        self,
+        data: Union[DataLoader, Dataset],
+        device: torch.device = torch.device('cuda', 0),
+        q_size: int = 10,
+        num_batches: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -43,20 +51,20 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=
         self.q_size = q_size

         self.load_stream = torch.cuda.Stream(device=device)
-        self.queue = Queue(maxsize=self.q_size)
+        self.queue: Queue = Queue(maxsize=self.q_size)

         self.idx = 0

         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')

-    def load_loop(self): # The loop that will load into the queue in the background
+    def load_loop(self) -> None: # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
                 break

     # Recursive loading for each instance based on torch.utils.data.default_collate
-    def load_instance(self, sample):
+    def load_instance(self, sample: Any) -> Any:
         elem_type = type(sample)

         if torch.is_tensor(sample):
@@ -80,16 +88,19 @@ def load_instance(self, sample):
         else:
             return sample

-    def __iter__(self):
+    def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:
+
+        # yapf: disable
+        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501
             self.worker = Thread(target=self.load_loop)
+        # yapf: enable
Comment on lines +95 to +98:

Contributor: Just for the record, I had a collision here between yapf and flake8. Lightning-AI/pytorch-lightning#5591

akihironitta (Contributor) on Jan 21, 2021: Applying yapf will cause flake8's [W503] error (line break before binary operator) without these ignores.
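
For context on the collision: W503 fires when a continuation line starts with a binary operator, which is how yapf can wrap long boolean conditions. An illustrative snippet (not from this PR):

    # the formatter wraps the condition so the operator opens the next line;
    # flake8 then reports W503 (line break before binary operator) on that line
    ready = (not queue.empty()
             and worker.is_alive())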

             self.worker.daemon = True
             self.worker.start()
         return self

-    def __next__(self):
+    def __next__(self) -> torch.Tensor:
         # If we've reached the number of batches to return
         # or the queue is empty and the worker is dead then exit
         done = not self.worker.is_alive() and self.queue.empty()
@@ -105,5 +116,5 @@ def __next__(self):
         self.idx += 1
         return out

-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
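
With the annotations in place, typical use reads directly off the signature: pass a Dataset (plus DataLoader kwargs) or an existing DataLoader, then iterate. A usage sketch, assuming a CUDA device is available (the dataset and batch size are illustrative):

    import torch
    from torch.utils.data import TensorDataset

    from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

    # 256 random feature vectors; extra kwargs are forwarded to the internal DataLoader
    dataset = TensorDataset(torch.randn(256, 32))
    loader = AsynchronousLoader(dataset, device=torch.device('cuda', 0), q_size=10, batch_size=16)

    for batch in loader:  # batches are loaded onto the GPU on a side CUDA stream
        ...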
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.mnist_dataset import BinaryMNIST
@@ -76,7 +76,7 @@ def __init__(
                 "You want to use transforms loaded from `torchvision` which is not installed yet."
             )

-        super().__init__(
+        super().__init__( # type: ignore[misc]
Review comment (Contributor): mypy will raise errors without these # type: ignore[misc] due to the bug reported in python/mypy#6799.
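
A minimal sketch of the pattern behind that ignore, as described in python/mypy#6799 (class and argument names here are illustrative): a call that passes explicit keyword arguments together with `*args`/`**kwargs` is valid at runtime, but mypy can report it as giving an argument multiple values, under the `misc` error code.

    from typing import Any

    class Base:
        def __init__(self, data_dir: str = '.', val_split: int = 5000) -> None:
            ...

    class Child(Base):
        def __init__(self, data_dir: str = '.', *args: Any, **kwargs: Any) -> None:
            # valid at runtime, but mypy may report something like:
            #   "__init__" of "Base" gets multiple values for keyword argument "data_dir" [misc]
            super().__init__(  # type: ignore[misc]
                data_dir=data_dir,
                *args,
                **kwargs,
            )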

             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10

-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
10 changes: 5 additions & 5 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union

 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
@@ -85,7 +85,7 @@ def __init__(
                 returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(
+        super().__init__( # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10

-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):

     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
@@ -164,7 +164,7 @@ def __init__(
         """
         super().__init__(data_dir, val_split, num_workers, *args, **kwargs)

-        self.num_samples = num_samples
+        self.num_samples = num_samples # type: ignore[misc]
         self.labels = sorted(labels) if labels is not None else set(range(10))
         self.extra_args = dict(num_samples=self.num_samples, labels=self.labels)
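
`Callable` is a truthful return annotation for `default_transforms` because a `transform_lib.Compose` instance is just a callable mapping an input image to a tensor. A standalone sketch of the idea, mirroring the datamodule code above (`Callable[[Any], Tensor]` would be a tighter choice if desired):

    from typing import Callable

    from torchvision import transforms as transform_lib

    def default_transforms() -> Callable:
        # Compose objects implement __call__, so they satisfy Callable
        return transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=(0.5, ), std=(0.5, )),
        ])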
23 changes: 13 additions & 10 deletions pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,6 @@
+# type: ignore[override]
+from typing import Any, Callable
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -56,7 +59,7 @@ class CityscapesDataModule(LightningDataModule):
     """

     name = 'Cityscapes'
-    extra_args = {}
+    extra_args: dict = {}

     def __init__(
         self,
@@ -69,9 +72,9 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
@@ -109,14 +112,14 @@ def __init__(
         self.target_transforms = None

     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30

-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -143,7 +146,7 @@ def train_dataloader(self):
         )
         return loader

-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -170,7 +173,7 @@ def val_dataloader(self):
         )
         return loader

-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -196,7 +199,7 @@ def test_dataloader(self):
         )
         return loader

-    def _default_transforms(self):
+    def _default_transforms(self) -> Callable:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -205,7 +208,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms

-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Callable:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/experience_source.py
@@ -27,7 +27,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """

-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch

     def __iter__(self) -> Iterable:
@@ -240,7 +240,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""

-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
@@ -299,5 +299,5 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
         """
         total_reward = 0.0
         for exp in reversed(experiences):
-            total_reward = (self.gamma * total_reward) + exp.reward
+            total_reward = (self.gamma * total_reward) + exp.reward # type: ignore[attr-defined]
         return total_reward
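
`discount_rewards` folds the reward sequence from the last step backwards, so the result is the N-step discounted return R = r_1 + gamma*r_2 + gamma^2*r_3 + ... A standalone sketch of the same fold with a worked check (illustrative, not code from this PR):

    from typing import Sequence

    def discounted_return(rewards: Sequence[float], gamma: float = 0.99) -> float:
        # fold from the end: total accumulates the discounted tail at each step
        total = 0.0
        for r in reversed(rewards):
            total = gamma * total + r
        return total

    # worked check with gamma=0.5: 1 + 0.5*2 + 0.25*3 = 2.75
    assert discounted_return([1.0, 2.0, 3.0], gamma=0.5) == 2.75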
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
@@ -76,7 +76,7 @@ def __init__(
                 'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.'
             )

-        super().__init__(
+        super().__init__( # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10

-    def default_transforms(self):
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))