Adding types to some of datamodules #462

Merged Jan 20, 2021

Commits (81)
Changes shown below are from 9 of the 81 commits.
34f5f07  Adding types to datamodules (briankosw, Dec 18, 2020)
2b55b32  Fixing typing imports (briankosw, Dec 18, 2020)
ac3377d  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
a4c39c7  Remove more torchvision.transforms typing (briankosw, Dec 19, 2020)
ffa0cb9  Removing return typing (briankosw, Dec 21, 2020)
5e6c5d4  Add `None` for optional arguments (briankosw, Dec 21, 2020)
3a5a0ab  Remove unnecessary import (briankosw, Dec 21, 2020)
30579ed  Remove unnecessary import (briankosw, Dec 21, 2020)
5de590c  Merge branch 'types/datamodules' of https://github.com/briankosw/pyto… (briankosw, Dec 21, 2020)
c675931  Add `None` return type (briankosw, Dec 21, 2020)
267649c  Add type for torchvision transforms (briankosw, Jan 5, 2021)
3938fff  Adding types to datamodules (briankosw, Dec 18, 2020)
cae3f46  Fixing typing imports (briankosw, Dec 18, 2020)
7af027e  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
7bec605  Remove more torchvision.transforms typing (briankosw, Dec 19, 2020)
afbc918  Removing return typing (briankosw, Dec 21, 2020)
17ce335  Add `None` for optional arguments (briankosw, Dec 21, 2020)
d09f98d  Remove unnecessary import (briankosw, Dec 21, 2020)
b61fdc0  Add `None` return type (briankosw, Dec 21, 2020)
f2f4305  Add type for torchvision transforms (briankosw, Jan 5, 2021)
cd09554  enable check (Borda, Jan 5, 2021)
797a9d1  Merge branch 'types/datamodules' of https://github.com/briankosw/pyto… (briankosw, Jan 6, 2021)
0fcd186  Adding types to datamodules (briankosw, Dec 18, 2020)
a430696  Fixing typing imports (briankosw, Dec 18, 2020)
a84551e  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
685162c  Remove more torchvision.transforms typing (briankosw, Dec 19, 2020)
0140837  Removing return typing (briankosw, Dec 21, 2020)
a6b8d4a  Add `None` for optional arguments (briankosw, Dec 21, 2020)
de35a55  Remove unnecessary import (briankosw, Dec 21, 2020)
f521b79  Add `None` return type (briankosw, Dec 21, 2020)
fa0d271  Add type for torchvision transforms (briankosw, Jan 5, 2021)
0bc9f7b  Adding types to datamodules (briankosw, Dec 18, 2020)
05fcef2  Fixing typing imports (briankosw, Dec 18, 2020)
d92604b  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
a9c641b  Removing return typing (briankosw, Dec 21, 2020)
314329c  Add `None` return type (briankosw, Dec 21, 2020)
14ea6b7  enable check (Borda, Jan 5, 2021)
e97332c  Merge branch 'types/datamodules' of https://github.com/briankosw/pyto… (briankosw, Jan 12, 2021)
8a7c6f1  Adding types to datamodules (briankosw, Dec 18, 2020)
3b0ee3c  Fixing typing imports (briankosw, Dec 18, 2020)
3d1c9a1  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
6cac909  Remove more torchvision.transforms typing (briankosw, Dec 19, 2020)
c1ea0fb  Removing return typing (briankosw, Dec 21, 2020)
c04caab  Add `None` for optional arguments (briankosw, Dec 21, 2020)
5b6bf64  Remove unnecessary import (briankosw, Dec 21, 2020)
7ce736d  Add `None` return type (briankosw, Dec 21, 2020)
7309ade  Add type for torchvision transforms (briankosw, Jan 5, 2021)
cc154a7  Adding types to datamodules (briankosw, Dec 18, 2020)
cff7330  Fixing typing imports (briankosw, Dec 18, 2020)
3bbc189  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
5b2401b  Removing return typing (briankosw, Dec 21, 2020)
7eb32b8  Add `None` return type (briankosw, Dec 21, 2020)
47cca32  enable check (Borda, Jan 5, 2021)
52e4811  Adding types to datamodules (briankosw, Dec 18, 2020)
64f871e  Fixing typing imports (briankosw, Dec 18, 2020)
bf6ee1c  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
b5c6dcc  Removing return typing (briankosw, Dec 21, 2020)
984a962  Add `None` return type (briankosw, Dec 21, 2020)
7894471  Adding types to datamodules (briankosw, Dec 18, 2020)
3443883  Fixing typing imports (briankosw, Dec 18, 2020)
51f8f16  Removing torchvision.transforms from return typing (briankosw, Dec 19, 2020)
3062dba  Removing return typing (briankosw, Dec 21, 2020)
53ebe33  Add `None` return type (briankosw, Dec 21, 2020)
ec068e4  Merge branch 'types/datamodules' of https://github.com/briankosw/pyto… (briankosw, Jan 12, 2021)
c15efdb  Fix rebasing mistakes (briankosw, Jan 12, 2021)
7bc0c37  Fix flake8 (briankosw, Jan 12, 2021)
a5f3e4f  Fix yapf format (briankosw, Jan 12, 2021)
e0c05a2  Merge branch 'master' into types/datamodules (Borda, Jan 18, 2021)
bd97183  Merge branch 'master' into briankosw-types/datamodules (akihironitta, Jan 19, 2021)
b9c910d  Add types and skip mypy checks on some files (akihironitta, Jan 20, 2021)
9e222d0  Fix setup.cfg (akihironitta, Jan 20, 2021)
0c54fdd  Add missing import (akihironitta, Jan 20, 2021)
8b2e196  isort (akihironitta, Jan 20, 2021)
9c5dd5c  yapf (akihironitta, Jan 20, 2021)
c6c97e1  mypy please... (akihironitta, Jan 20, 2021)
4ac8a5b  Please be quiet mypy and flake8 (akihironitta, Jan 20, 2021)
e847ead  yapf... (akihironitta, Jan 20, 2021)
1839438  Disable all of yapf, flake8, and mypy (akihironitta, Jan 20, 2021)
097df6d  Use Callable (akihironitta, Jan 20, 2021)
9e00c5d  Use Callable (akihironitta, Jan 20, 2021)
af563a5  Add missing import (akihironitta, Jan 20, 2021)
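The final "Use Callable" commits swap concrete torchvision transform types for `typing.Callable`: transforms are plain callables, so annotating them as `Callable` avoids importing torchvision purely for a type hint. A minimal, torch-free sketch of that pattern (`ToyDataModule` and `normalize` are hypothetical stand-ins, not Bolts code):

```python
from typing import Callable, Optional


def normalize(x: float) -> float:
    """Stand-in for a torchvision transform: any callable works."""
    return x / 255.0


class ToyDataModule:
    # Annotating the transform as Optional[Callable] rather than a concrete
    # torchvision type keeps torchvision out of the module's imports.
    def __init__(self, train_transforms: Optional[Callable] = None) -> None:
        self.train_transforms = train_transforms if train_transforms is not None else normalize

    def apply(self, x: float) -> float:
        return self.train_transforms(x)


print(ToyDataModule().apply(51.0))  # 0.2
```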
14 changes: 11 additions & 3 deletions pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
import re
from queue import Queue
from threading import Thread
from typing import Any, Optional, Union

import torch
from torch._six import container_abcs, string_classes
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset


class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
constructing one here
"""

def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
def __init__(
self,
data: Union[DataLoader, Dataset],
device: torch.device = torch.device('cuda', 0),
q_size: int = 10,
num_batches: Optional[int] = None,
**kwargs: Any
):
if isinstance(data, torch.utils.data.DataLoader):
self.dataloader = data
else:
@@ -105,5 +113,5 @@ def __next__(self):
self.idx += 1
return out

def __len__(self):
def __len__(self) -> int:
return self.num_batches
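The hunk above types `data` as `Union[DataLoader, Dataset]` and then narrows it with `isinstance`, wrapping a bare dataset in a loader when needed. A torch-free sketch of the same Union-narrowing pattern (`SimpleLoader` and `AsyncishLoader` are hypothetical stand-ins for `DataLoader` and `AsynchronousLoader`):

```python
from typing import Any, Iterable, Optional, Union


class SimpleLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, data: Iterable[int], batch_size: int = 1) -> None:
        items = list(data)
        self.batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    def __len__(self) -> int:
        return len(self.batches)


class AsyncishLoader:
    """Hypothetical stand-in for AsynchronousLoader, showing the Union narrowing."""

    def __init__(
        self,
        data: Union[SimpleLoader, Iterable[int]],
        num_batches: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        if isinstance(data, SimpleLoader):
            # Already a loader: use it as-is.
            self.loader = data
        else:
            # Raw data: build a loader, forwarding any loader kwargs.
            self.loader = SimpleLoader(data, **kwargs)
        self.num_batches = num_batches if num_batches is not None else len(self.loader)

    def __len__(self) -> int:
        return self.num_batches


print(len(AsyncishLoader(range(8), batch_size=4)))  # 2
```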
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
):
"""
Args:
data_dir: Where to save/load the data
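The resolved review thread above concerns the `-> None` annotation on `__init__`. PEP 484 requires `__init__` to return `None`, and by default mypy only type-checks the bodies of functions that carry at least one annotation, so the return annotation also opts a bare `__init__` into checking. A small illustration (both classes are hypothetical):

```python
class Unchecked:
    def __init__(self):
        # No annotations at all: by default mypy treats this def as untyped
        # and skips its body, so the error below goes unreported unless flags
        # like --check-untyped-defs are enabled. (This class is never
        # instantiated here, so it does not raise at runtime either.)
        self.x = "not an int" + 1


class Checked:
    def __init__(self) -> None:
        # The -> None annotation makes this a typed def, so mypy checks it.
        self.x: int = 0


print(Checked().x)  # 0
```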
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
):
"""
Args:
data_dir: Where to save/load the data
@@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):

def __init__(
self,
data_dir: str,
data_dir: Optional[str] = None,
val_split: int = 50,
num_workers: int = 16,
num_samples: int = 100,
labels: Optional[Sequence] = (1, 5, 8),
*args: Any,
**kwargs: Any,
) -> None:
):
"""
Args:
data_dir: where to save/load the data
14 changes: 8 additions & 6 deletions pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,5 @@
from typing import Any

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

@@ -69,8 +71,8 @@ def __init__(
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
"""
Args:
@@ -109,14 +111,14 @@ def __init__(
self.target_transforms = None

@property
def num_classes(self):
def num_classes(self) -> int:
"""
Return:
30
"""
return 30

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
"""
Cityscapes train set
"""
@@ -141,7 +143,7 @@ def train_dataloader(self):
)
return loader

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
"""
Cityscapes val set
"""
@@ -166,7 +168,7 @@ def val_dataloader(self):
)
return loader

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
"""
Cityscapes test set
"""
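The cityscapes hunks annotate `num_classes` as `-> int` and the dataloader methods as `-> DataLoader`; for a property, the return annotation goes on the getter function itself. A minimal sketch (`DummyModule` is hypothetical):

```python
class DummyModule:
    """Hypothetical datamodule-like class showing a typed property getter."""

    @property
    def num_classes(self) -> int:
        # The annotation lives on the getter; DummyModule().num_classes is an int.
        return 30


print(DummyModule().num_classes)  # 30
```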
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
):
"""
Args:
data_dir: Where to save/load the data
16 changes: 8 additions & 8 deletions pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Any, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -58,8 +58,8 @@ def __init__(
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
"""
Args:
@@ -94,7 +94,7 @@ def __init__(
self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes

@property
def num_classes(self):
def num_classes(self) -> int:
"""
Return:

@@ -103,7 +103,7 @@ def num_classes(self):
"""
return 1000

def _verify_splits(self, data_dir, split):
def _verify_splits(self, data_dir: str, split: str):
dirs = os.listdir(data_dir)

if split not in dirs:
@@ -138,7 +138,7 @@ def prepare_data(self):
UnlabeledImagenet.generate_meta_bins(path)
""")

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
"""
Uses the train split of imagenet2012 and puts away a portion of it for the validation split
"""
@@ -160,7 +160,7 @@ def train_dataloader(self):
)
return loader

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
"""
Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class`

@@ -185,7 +185,7 @@ def val_dataloader(self):
)
return loader

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
"""
Uses the validation split of imagenet2012 for testing
"""
13 changes: 7 additions & 6 deletions pl_bolts/datamodules/kitti_datamodule.py
@@ -1,4 +1,5 @@
import os
from typing import Any, Optional

import torch
from pytorch_lightning import LightningDataModule
@@ -21,7 +22,7 @@ class KittiDataModule(LightningDataModule):

def __init__(
self,
data_dir: str,
data_dir: Optional[str] = None,
val_split: float = 0.2,
test_split: float = 0.1,
num_workers: int = 16,
@@ -30,8 +31,8 @@ def __init__(
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
"""
Kitti train, validation and test dataloaders.
@@ -100,7 +101,7 @@ def __init__(
lengths=[train_len, val_len, test_len],
generator=torch.Generator().manual_seed(self.seed))

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
loader = DataLoader(
self.trainset,
batch_size=self.batch_size,
@@ -111,7 +112,7 @@ def train_dataloader(self):
)
return loader

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
loader = DataLoader(
self.valset,
batch_size=self.batch_size,
@@ -122,7 +123,7 @@ def val_dataloader(self):
)
return loader

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
loader = DataLoader(
self.testset,
batch_size=self.batch_size,
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
):
"""
Args:
data_dir: Where to save/load the data
40 changes: 24 additions & 16 deletions pl_bolts/datamodules/sklearn_datamodule.py
@@ -1,5 +1,5 @@
import math
from typing import Any
from typing import Any, Tuple

import numpy as np
import torch
@@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran
self.X_transform = X_transform
self.y_transform = y_transform

def __len__(self):
def __len__(self) -> int:
return len(self.X)

def __getitem__(self, idx):
def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
x = self.X[idx].astype(np.float32)
y = self.Y[idx]

@@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_
self.X_transform = X_transform
self.y_transform = y_transform

def __len__(self):
def __len__(self) -> int:
return len(self.X)

def __getitem__(self, idx):
def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.X[idx].float()
y = self.Y[idx]

@@ -145,14 +145,14 @@ def __init__(
x_val=None, y_val=None,
x_test=None, y_test=None,
val_split=0.2, test_split=0.1,
num_workers=2,
random_state=1234,
shuffle=True,
num_workers: int = 2,
random_state: int = 1234,
shuffle: bool = True,
batch_size: int = 16,
pin_memory=False,
drop_last=False,
*args,
**kwargs,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
):

super().__init__(*args, **kwargs)
@@ -193,12 +193,20 @@ def __init__(

self._init_datasets(X, y, x_val, y_val, x_test, y_test)

def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
def _init_datasets(
self,
X: np.ndarray,
y: np.ndarray,
x_val: np.ndarray,
y_val: np.ndarray,
x_test: np.ndarray,
y_test: np.ndarray
):
self.train_dataset = SklearnDataset(X, y)
self.val_dataset = SklearnDataset(x_val, y_val)
self.test_dataset = SklearnDataset(x_test, y_test)

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
@@ -209,7 +217,7 @@ def train_dataloader(self):
)
return loader

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
loader = DataLoader(
self.val_dataset,
batch_size=self.batch_size,
@@ -220,7 +228,7 @@ def val_dataloader(self):
)
return loader

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
loader = DataLoader(
self.test_dataset,
batch_size=self.batch_size,
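The sklearn hunks give the dataset dunders concrete return types: `__len__ -> int` and `__getitem__ -> Tuple[np.ndarray, np.ndarray]`. A stripped-down sketch without the optional transforms (`TinySklearnDataset` is hypothetical, not the Bolts class):

```python
from typing import Tuple

import numpy as np


class TinySklearnDataset:
    """Hypothetical cut-down SklearnDataset with typed dunder methods."""

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        self.X = X
        self.y = y

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        # Mirror the float32 cast done in the real dataset.
        return self.X[idx].astype(np.float32), self.y[idx]


ds = TinySklearnDataset(np.arange(6).reshape(3, 2), np.array([0, 1, 0]))
x, y = ds[1]
print(len(ds), x.tolist(), int(y))  # 3 [2.0, 3.0] 1
```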