Skip to content

Commit

Permalink
Adds type annotations (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
oke-aditya authored Nov 22, 2020
1 parent 9f7239b commit 75080a7
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 41 deletions.
23 changes: 13 additions & 10 deletions quickvision/models/classification/cnn/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# https://github.com/oke-aditya/pytorch_cnn_trainer

import torch
from torch import nn
from torch.cuda import amp
from quickvision import utils
from quickvision.metrics import accuracy
Expand All @@ -15,7 +16,8 @@
"val_sanity_fit", "sanity_fit", ]


def train_step(model, train_loader, criterion, device, optimizer,
def train_step(model: nn.Module, train_loader, criterion,
device: str, optimizer,
scheduler=None, num_batches: int = None,
log_interval: int = 100, grad_penalty: bool = False,
scaler=None,):
Expand Down Expand Up @@ -141,7 +143,8 @@ def train_step(model, train_loader, criterion, device, optimizer,
return metrics


def val_step(model, val_loader, criterion, device, num_batches=None,
def val_step(model: nn.Module, val_loader, criterion,
device: str, num_batches=None,
log_interval: int = 100):

"""
Expand Down Expand Up @@ -214,8 +217,8 @@ def val_step(model, val_loader, criterion, device, num_batches=None,
return metrics


def fit(model, epochs, train_loader, val_loader, criterion,
device, optimizer, scheduler=None, early_stopper=None,
def fit(model: nn.Module, epochs: int, train_loader, val_loader, criterion,
device: str, optimizer, scheduler=None, early_stopper=None,
num_batches: int = None, log_interval: int = 100,
grad_penalty: bool = False, fp16: bool = False,
swa_start: int = None, swa_scheduler=None,):
Expand Down Expand Up @@ -310,8 +313,8 @@ def fit(model, epochs, train_loader, val_loader, criterion,
return history


def train_sanity_fit(model, train_loader, criterion,
device, num_batches: int = None, log_interval: int = 100,
def train_sanity_fit(model: nn.Module, train_loader, criterion,
device: str, num_batches: int = None, log_interval: int = 100,
grad_penalty: bool = False, fp16: bool = False,):

"""
Expand Down Expand Up @@ -381,8 +384,8 @@ def train_sanity_fit(model, train_loader, criterion,
return True


def val_sanity_fit(model, val_loader,
criterion, device, num_batches: int = None,
def val_sanity_fit(model: nn.Module, val_loader,
criterion, device: str, num_batches: int = None,
log_interval: int = 100,):

"""
Expand Down Expand Up @@ -433,8 +436,8 @@ def val_sanity_fit(model, val_loader,
return True


def sanity_fit(model, train_loader, val_loader,
criterion, device, num_batches: int = None,
def sanity_fit(model: nn.Module, train_loader, val_loader,
criterion, device: str, num_batches: int = None,
log_interval: int = 100, grad_penalty: bool = False,
fp16: bool = False,):

Expand Down
6 changes: 3 additions & 3 deletions quickvision/models/components/torchvision_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__all__ = ["create_torchvision_backbone"]


def _create_backbone_generic(model, out_channels: int):
def _create_backbone_generic(model: nn.Module, out_channels: int):
"""
Generic Backbone creater. It removes the last linear layer.
Args:
Expand All @@ -23,7 +23,7 @@ def _create_backbone_generic(model, out_channels: int):

# Use this when you have Adaptive Pooling layer in End.
# When Model.features is not applicable.
def _create_backbone_adaptive(model, out_channels: int = None):
def _create_backbone_adaptive(model: nn.Module, out_channels: int = None):
"""
Creates backbone by removing linear after Adaptive Pooling layer.
Args:
Expand All @@ -36,7 +36,7 @@ def _create_backbone_adaptive(model, out_channels: int = None):
return _create_backbone_generic(model, out_channels=out_channels)


def _create_backbone_features(model, out_channels: int):
def _create_backbone_features(model: nn.Module, out_channels: int):
"""
Creates backbone from feature sequential block.
Args:
Expand Down
13 changes: 8 additions & 5 deletions quickvision/models/detection/detr/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import nn
from torch.cuda import amp
from quickvision import utils
from tqdm import tqdm
Expand All @@ -9,8 +10,9 @@
"val_sanity_fit", "sanity_fit", ]


def train_step(model, train_loader, criterion, device, optimizer, scheduler=None,
num_batches: int = None, log_interval: int = 100, scaler=None,):
def train_step(model: nn.Module, train_loader, criterion, device: str,
optimizer, scheduler=None, num_batches: int = None,
log_interval: int = 100, scaler=None,):
"""
Performs one step of training. Calculates loss, forward pass, computes gradient and returns metrics.
Args:
Expand Down Expand Up @@ -105,7 +107,7 @@ def train_step(model, train_loader, criterion, device, optimizer, scheduler=None
return metrics


def val_step(model, val_loader, criterion, device,
def val_step(model: nn.Module, val_loader, criterion, device,
num_batches: int = None, log_interval: int = 100):
"""
Performs one step of validation. Calculates loss, forward pass and returns metrics.
Expand Down Expand Up @@ -177,8 +179,9 @@ def val_step(model, val_loader, criterion, device,
return metrics


def fit(model, epochs, train_loader, val_loader, criterion,
device, optimizer, scheduler=None,
def fit(model: nn.Module, epochs: int, train_loader,
val_loader, criterion,
device: str, optimizer, scheduler=None,
num_batches: int = None, log_interval: int = 100,
fp16: bool = False, ):

Expand Down
4 changes: 2 additions & 2 deletions quickvision/models/detection/detr/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class vision_detr(nn.Module):
num_queries: Number of queries for transformer in Detr.
backbone: Backbone created from create_detr_backbone.
"""
def __init__(self, num_classes, num_queries, backbone):
def __init__(self, num_classes: int, num_queries: int, backbone: str):
super().__init__()
self.num_classes = num_classes
self.num_queries = num_queries
Expand All @@ -29,7 +29,7 @@ def forward(self, images):
return self.model(images)


def create_vision_detr(num_classes: int, num_queries: int, backbone):
def create_vision_detr(num_classes: int, num_queries: int, backbone: str):
"""
Creates Detr Model for Object Detection
Args:
Expand Down
21 changes: 11 additions & 10 deletions quickvision/models/detection/faster_rcnn/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import nn, Tensor
from torch.cuda import amp
from quickvision import utils
from tqdm import tqdm
Expand All @@ -11,7 +12,7 @@
"val_sanity_fit", "sanity_fit", ]


def train_step(model, train_loader, device, optimizer,
def train_step(model: nn.Module, train_loader, device: str, optimizer,
scheduler=None, num_batches: int = None,
log_interval: int = 100, scaler=None,):

Expand Down Expand Up @@ -111,7 +112,7 @@ def train_step(model, train_loader, device, optimizer,
return metrics


def val_step(model, val_loader, device, num_batches=None,
def val_step(model: nn.Module, val_loader, device: str, num_batches: int = None,
log_interval: int = 100):

"""
Expand Down Expand Up @@ -171,8 +172,8 @@ def val_step(model, val_loader, device, num_batches=None,
return metrics


def fit(model, epochs, train_loader, val_loader,
device, optimizer, scheduler=None,
def fit(model: nn.Module, epochs: int, train_loader, val_loader,
device: str, optimizer, scheduler=None,
num_batches: int = None, log_interval: int = 100,
fp16: bool = False, ):

Expand Down Expand Up @@ -224,8 +225,8 @@ def fit(model, epochs, train_loader, val_loader,
return history


def train_sanity_fit(model, train_loader,
device, num_batches: int = None, log_interval: int = 100,
def train_sanity_fit(model: nn.Module, train_loader,
device: str, num_batches: int = None, log_interval: int = 100,
fp16: bool = False,):

"""
Expand All @@ -245,8 +246,8 @@ def train_sanity_fit(model, train_loader,
pass


def val_sanity_fit(model, val_loader,
device, num_batches: int = None,
def val_sanity_fit(model: nn.Module, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100,):

"""
Expand All @@ -265,8 +266,8 @@ def val_sanity_fit(model, val_loader,
pass


def sanity_fit(model, train_loader, val_loader,
device, num_batches: int = None,
def sanity_fit(model: nn.Module, train_loader, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100, fp16: bool = False,):

"""
Expand Down
23 changes: 12 additions & 11 deletions quickvision/models/detection/retinanet/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import nn
from torch.cuda import amp
from quickvision import utils
from tqdm import tqdm
Expand All @@ -10,7 +11,7 @@
"val_sanity_fit", "sanity_fit", ]


def train_step(model, train_loader, device, optimizer,
def train_step(model: nn.Module, train_loader, device: str, optimizer,
scheduler=None, num_batches: int = None,
log_interval: int = 100, scaler=None,):

Expand Down Expand Up @@ -102,8 +103,8 @@ def train_step(model, train_loader, device, optimizer,
return metrics


def val_step(model, val_loader, device, num_batches=None,
log_interval: int = 100):
def val_step(model: nn.Module, val_loader, device: str,
num_batches=None, log_interval: int = 100):

"""
Performs one step of validation. Calculates loss, forward pass and returns metrics.
Expand Down Expand Up @@ -162,8 +163,8 @@ def val_step(model, val_loader, device, num_batches=None,
return metrics


def fit(model, epochs, train_loader, val_loader,
device, optimizer, scheduler=None,
def fit(model: nn.Module, epochs: int, train_loader, val_loader,
device: str, optimizer, scheduler=None,
num_batches: int = None, log_interval: int = 100,
fp16: bool = False, ):

Expand Down Expand Up @@ -215,19 +216,19 @@ def fit(model, epochs, train_loader, val_loader,
return history


def train_sanity_fit(model, train_loader,
device, num_batches: int = None, log_interval: int = 100,
def train_sanity_fit(model: nn.Module, train_loader,
device: str, num_batches: int = None, log_interval: int = 100,
fp16: bool = False,):
pass


def val_sanity_fit(model, val_loader,
device, num_batches: int = None,
def val_sanity_fit(model: nn.Module, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100,):
pass


def sanity_fit(model, train_loader, val_loader,
device, num_batches: int = None,
def sanity_fit(model: nn.Module, train_loader, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100, fp16: bool = False,):
pass

0 comments on commit 75080a7

Please sign in to comment.