Complete sanity functions for faster_rcnn and RetinaNet (#46)
* Complete sanity fit functions for faster_rcnn

* Add amp and logging with time for rcnn sanity fit

* Test frcnn sanity fit

* Format files

* bug fixes

* fixes bugs, removes formatting

* adds tests for retinanet

* bug fixes

* fixes bugs

* tries fixing bugs

* tries a fix

* fixes bug

* lower CI

Co-authored-by: Aditya Oke <47158509+oke-aditya@users.noreply.github.com>
ramaneswaran and oke-aditya authored Nov 23, 2020
1 parent 75080a7 commit 0d1ee43
Showing 7 changed files with 282 additions and 26 deletions.
2 changes: 1 addition & 1 deletion quickvision/models/detection/faster_rcnn/__init__.py
@@ -1,6 +1,6 @@
from quickvision.models.detection.faster_rcnn.model_factory import (
create_vision_fastercnn,
create_fastercnn_backbone
create_fastercnn_backbone,
)

from quickvision.models.detection.faster_rcnn.engine import (
81 changes: 78 additions & 3 deletions quickvision/models/detection/faster_rcnn/engine.py
@@ -243,7 +243,47 @@ def train_sanity_fit(model: nn.Module, train_loader,
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
fp16 : (optional) If True, uses PyTorch native mixed precision training.
"""
pass

model = model.to(device)
model.train()

cnt = 0
last_idx = len(train_loader) - 1
train_sanity_start = time.time()

if fp16 is True:
scaler = amp.GradScaler()
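# Note: the scaler mirrors the real fit loop but is never stepped here,
# since sanity fit performs no backward pass.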

for batch_idx, (inputs, targets) in enumerate(train_loader):
last_batch = batch_idx == last_idx
images = list(image.to(device) for image in inputs)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

if fp16 is True:
with amp.autocast():
loss_dict = model(images, targets)
else:
loss_dict = model(images, targets)

cnt += 1

if last_batch or batch_idx % log_interval == 0:
print(f"Train sanity check passed for batch till {batch_idx} batches")

if num_batches is not None:
if cnt >= num_batches:
print(f"Done till {num_batches} train batches")
print("All specified batches done")
train_sanity_end = time.time()
print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")
return True

train_sanity_end = time.time()

print("All specified batches done")
print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")

return True


def val_sanity_fit(model: nn.Module, val_loader,
@@ -263,7 +303,38 @@ def val_sanity_fit(model: nn.Module, val_loader,
Useful if data is too big even for a sanity check.
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
"""
pass
model = model.to(device)
model.eval()

cnt = 0
val_sanity_start = time.time()
last_idx = len(val_loader) - 1

with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(val_loader):
last_batch = batch_idx == last_idx
images = list(image.to(device) for image in inputs)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

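# In eval mode, torchvision detection models return predictions rather than losses.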
out = model(images)
cnt += 1

if last_batch or (batch_idx % log_interval) == 0:
print(f"Val sanity check passed for batch till {batch_idx} batches")

if num_batches is not None:
if cnt >= num_batches:
print(f"Done till {num_batches} validation batches")
print("All specified batches done")
val_sanity_end = time.time()
print(f"Val sanity fit check passed in time {val_sanity_end-val_sanity_start}")
return True

val_sanity_end = time.time()
print("All specified batches done")
print(f"Validation sanity check pased in time {val_sanity_end-val_sanity_start}")

return True


def sanity_fit(model: nn.Module, train_loader, val_loader,
@@ -285,4 +356,8 @@ def sanity_fit(model: nn.Module, train_loader, val_loader,
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
"""

pass
sanity_train = train_sanity_fit(model, train_loader, device, num_batches, log_interval, fp16)

sanity_val = val_sanity_fit(model, val_loader, device, num_batches, log_interval)

return sanity_train and sanity_val
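A minimal end-to-end sketch of the new sanity-fit API. Everything below except the quickvision calls is an illustrative assumption (the dummy dataset, image sizes, and batch size are not part of this commit):

import torch
from torch.utils.data import DataLoader, Dataset
from quickvision.models.detection import faster_rcnn

class DummyDetectionDataset(Dataset):
    # Tiny synthetic dataset, just enough to drive a sanity check.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        image = torch.rand(3, 64, 64)
        target = {"boxes": torch.tensor([[10.0, 10.0, 40.0, 40.0]]),
                  "labels": torch.tensor([1])}
        return image, target

# Detection models take (list of images, list of target dicts), so the
# collate function keeps the two halves of the batch as parallel tuples.
loader = DataLoader(DummyDetectionDataset(), batch_size=2,
                    collate_fn=lambda batch: tuple(zip(*batch)))

model = faster_rcnn.create_vision_fastercnn(num_classes=3)
device = "cuda" if torch.cuda.is_available() else "cpu"
faster_rcnn.sanity_fit(model, loader, loader, device, num_batches=2)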
11 changes: 6 additions & 5 deletions quickvision/models/detection/faster_rcnn/lightning_trainer.py
@@ -4,7 +4,7 @@
from quickvision.models.components import create_torchvision_backbone
from quickvision.models.detection.faster_rcnn import create_fastercnn_backbone
from quickvision.models.detection.utils import _evaluate_iou, _evaluate_giou
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection.faster_rcnn import (fasterrcnn_resnet50_fpn, FasterRCNN, FastRCNNPredictor,)

__all__ = ["lit_frcnn"]

@@ -16,8 +16,8 @@ class lit_frcnn(pl.LightningModule):

def __init__(self, learning_rate: float = 0.0001, num_classes: int = 91,
backbone: str = None, fpn: bool = True,
pretrained_backbone: str = None, trainable_backbone_layers: int = 3,
**kwargs, ):
pretrained_backbone: str = None, trainable_backbone_layers: int = 3, **kwargs,):

"""
Args:
learning_rate: the learning rate
@@ -39,7 +39,7 @@ def __init__(self, learning_rate: float = 0.0001, num_classes: int = 91,

else:
backbone_model = create_fastercnn_backbone(self.backbone, fpn, pretrained_backbone,
trainable_backbone_layers, **kwargs)
trainable_backbone_layers, **kwargs,)
self.model = FasterRCNN(backbone_model, num_classes=num_classes, **kwargs)

def forward(self, x):
@@ -70,4 +70,5 @@ def validation_epoch_end(self, outs):
return {"avg_val_iou": avg_iou, "avg_val_giou": avg_giou, "log": logs}

def configure_optimizers(self):
return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9, weight_decay=0.005,)
return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
momentum=0.9, weight_decay=0.005,)
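The Lightning wrapper can then be driven by a standard pl.Trainer; a hedged sketch reusing the dummy loader idea above (trainer settings are illustrative, not from this commit):

import pytorch_lightning as pl
from quickvision.models.detection.faster_rcnn import lightning_trainer

module = lightning_trainer.lit_frcnn(learning_rate=1e-4, num_classes=3)
# fast_dev_run pushes a single batch through train and val, much like a sanity fit.
trainer = pl.Trainer(fast_dev_run=True)
# trainer.fit(module, loader, loader)  # loaders as in the earlier sketch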
7 changes: 3 additions & 4 deletions quickvision/models/detection/faster_rcnn/model_factory.py
@@ -1,6 +1,5 @@

import torch.nn as nn
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection.faster_rcnn import (fasterrcnn_resnet50_fpn, FasterRCNN, FastRCNNPredictor,)
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from quickvision.models.components import create_torchvision_backbone

@@ -29,7 +28,7 @@ def create_vision_fastercnn(num_classes: int = 91, backbone: nn.Module = None, *


def create_fastercnn_backbone(backbone: str, fpn: bool = True, pretrained: str = None,
trainable_backbone_layers: int = 3, **kwargs) -> nn.Module:
trainable_backbone_layers: int = 3, **kwargs,) -> nn.Module:

"""
Args:
@@ -47,7 +46,7 @@ def create_fastercnn_backbone(backbone: str, fpn: bool = True, pretrained: str =
# Creates a torchvision resnet model with fpn added.
print("Resnet FPN Backbones works only for imagenet weights")
backbone = resnet_fpn_backbone(backbone, pretrained=True,
trainable_layers=trainable_backbone_layers, **kwargs)
trainable_layers=trainable_backbone_layers, **kwargs,)
else:
# This does not create fpn backbone, it is supported for all models
print("FPN is not supported for Non Resnet Backbones")
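For completeness, the backbone factory as exercised by the new tests; "resnet18" is assumed to be among the FPN-supported names iterated over in tests/test_frcnn.py:

from quickvision.models.detection import faster_rcnn

backbone = faster_rcnn.create_fastercnn_backbone(backbone="resnet18", pretrained=None)
model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)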
127 changes: 124 additions & 3 deletions quickvision/models/detection/retinanet/engine.py
@@ -219,16 +219,137 @@ def fit(model: nn.Module, epochs: int, train_loader, val_loader,
def train_sanity_fit(model: nn.Module, train_loader,
device: str, num_batches: int = None, log_interval: int = 100,
fp16: bool = False,):
pass

"""
Performs sanity fit over the train loader.
Use this to dummy-check your fit function. It does not calculate metrics or do checkpointing.
It iterates over the train_loader for the given number of batches.
Note: It does not call loss.backward().
Args:
model : A PyTorch RetinaNet model.
train_loader : Train loader.
device : "cuda" or "cpu"
num_batches : (optional) Integer to limit sanity fit over certain batches.
Useful if data is too big even for a sanity check.
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
fp16 : (optional) If True, uses PyTorch native mixed precision training.
"""

model = model.to(device)
model.train()

cnt = 0
last_idx = len(train_loader) - 1
train_sanity_start = time.time()

if fp16 is True:
scaler = amp.GradScaler()
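# Note: the scaler mirrors the real fit loop but is never stepped here,
# since sanity fit performs no backward pass.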

for batch_idx, (inputs, targets) in enumerate(train_loader):
last_batch = batch_idx == last_idx
images = list(image.to(device) for image in inputs)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

if fp16 is True:
with amp.autocast():
loss_dict = model(images, targets)
else:
loss_dict = model(images, targets)

cnt += 1

if last_batch or (batch_idx % log_interval) == 0:
print(f"Train sanity check passed for batch till {batch_idx} batches")

if num_batches is not None:
if cnt >= num_batches:
print(f"Done till {num_batches} train batches")
print("All specified batches done")
train_sanity_end = time.time()
print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")
return True

train_sanity_end = time.time()

print("All specified batches done")
print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")

return True


def val_sanity_fit(model: nn.Module, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100,):
pass

"""
Performs sanity fit over the validation loader.
Use this to dummy-check your fit function. It does not calculate metrics or do checkpointing.
It iterates over the val_loader for the given number of batches.
Note: It does not call loss.backward().
Args:
model : A PyTorch RetinaNet model.
val_loader : Validation loader.
device : "cuda" or "cpu"
num_batches : (optional) Integer to limit sanity fit over certain batches.
Useful if data is too big even for a sanity check.
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
"""
model = model.to(device)
model.eval()

cnt = 0
val_sanity_start = time.time()
last_idx = len(val_loader) - 1

with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(val_loader):
last_batch = batch_idx == last_idx
images = list(image.to(device) for image in inputs)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

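# In eval mode, torchvision detection models return predictions rather than losses.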
out = model(images)

cnt += 1

if last_batch or (batch_idx % log_interval) == 0:
print(f"Val sanity check passed for batch till {batch_idx} batches")

if num_batches is not None:
if cnt >= num_batches:
print(f"Done till {num_batches} validation batches")
print("All specified batches done")
val_sanity_end = time.time()
print(f"Val sanity fit check passed in time {val_sanity_end-val_sanity_start}")
return True

val_sanity_end = time.time()
print("All specified batches done")
print(f"Validation sanity check pased in time {val_sanity_end-val_sanity_start}")

return True


def sanity_fit(model: nn.Module, train_loader, val_loader,
device: str, num_batches: int = None,
log_interval: int = 100, fp16: bool = False,):
pass

"""
Performs sanity fit over both the train loader and the validation loader.
Use this to dummy-check your fit function. It does not calculate metrics or do checkpointing.
It iterates over both train_loader and val_loader for the given number of batches.
Note: It does not call loss.backward().
Args:
model : A PyTorch RetinaNet model.
train_loader : Training loader.
val_loader : Validation loader.
device : "cuda" or "cpu"
num_batches : (optional) Integer to limit sanity fit over certain batches.
Useful if data is too big even for a sanity check.
log_interval : (optional) Default 100. Logs progress after the specified number of batches.
fp16 : (optional) If True, uses PyTorch native mixed precision training.
"""

sanity_train = train_sanity_fit(model, train_loader, device, num_batches, log_interval, fp16)

sanity_val = val_sanity_fit(model, val_loader, device, num_batches, log_interval)

return sanity_train and sanity_val
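The RetinaNet engine mirrors the Faster RCNN one, so the earlier sketch carries over; the factory call below is an assumption about the retinanet module, whose model_factory is not part of this diff:

from quickvision.models.detection import retinanet

# Hypothetical factory name, by analogy with create_vision_fastercnn.
# model = retinanet.create_vision_retinanet(num_classes=3)
# retinanet.sanity_fit(model, loader, loader, "cpu", num_batches=2)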
40 changes: 35 additions & 5 deletions tests/test_frcnn.py
@@ -168,21 +168,51 @@ def test_fit_cuda(self):
# self.assertTrue(exp_k3 in history["val"].keys())

def test_train_sanity_fit(self):
pass
for bbone in fpn_supported_models:
backbone = faster_rcnn.create_fastercnn_backbone(backbone=bbone, pretrained=None)
self.assertTrue(isinstance(backbone, nn.Module))
frcnn_model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)
self.assertTrue(isinstance(frcnn_model, nn.Module))
result = faster_rcnn.train_sanity_fit(frcnn_model, train_loader, "cpu", num_batches=10)
self.assertTrue(result)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_train_sanity_fit_cuda(self):
pass
for bbone in fpn_supported_models:
backbone = faster_rcnn.create_fastercnn_backbone(backbone=bbone, pretrained=None)
self.assertTrue(isinstance(backbone, nn.Module))
frcnn_model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)
self.assertTrue(isinstance(frcnn_model, nn.Module))
result = faster_rcnn.train_sanity_fit(frcnn_model, train_loader, "cuda", num_batches=10, fp16=True)
self.assertTrue(result)

def test_val_sanity_fit(self):
pass
for bbone in fpn_supported_models:
backbone = faster_rcnn.create_fastercnn_backbone(backbone=bbone, pretrained=None)
self.assertTrue(isinstance(backbone, nn.Module))
frcnn_model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)
self.assertTrue(isinstance(frcnn_model, nn.Module))
result = faster_rcnn.val_sanity_fit(frcnn_model, val_loader, "cpu", num_batches=10)
self.assertTrue(result)

def test_sanity_fit(self):
pass
for bbone in fpn_supported_models:
backbone = faster_rcnn.create_fastercnn_backbone(backbone=bbone, pretrained=None)
self.assertTrue(isinstance(backbone, nn.Module))
frcnn_model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)
self.assertTrue(isinstance(frcnn_model, nn.Module))
result = faster_rcnn.sanity_fit(frcnn_model, train_loader, val_loader, "cpu", num_batches=10)
self.assertTrue(result)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_sanity_fit_cuda(self):
pass
for bbone in fpn_supported_models:
backbone = faster_rcnn.create_fastercnn_backbone(backbone=bbone, pretrained=None)
self.assertTrue(isinstance(backbone, nn.Module))
frcnn_model = faster_rcnn.create_vision_fastercnn(num_classes=3, backbone=backbone)
self.assertTrue(isinstance(frcnn_model, nn.Module))
result = faster_rcnn.sanity_fit(frcnn_model, train_loader, val_loader, "cuda", num_batches=10, fp16=True)
self.assertTrue(result)


class LightningTester(unittest.TestCase):
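Stepping back from the diff: the CUDA test variants above pass fp16=True, which routes the forward pass through amp.autocast. A standalone illustration of what that branch does (needs a CUDA device, matching the skipIf guard; the toy model here is not from the commit):

import torch
from torch.cuda import amp

model = torch.nn.Linear(4, 2).cuda()
x = torch.randn(8, 4, device="cuda")
with amp.autocast():
    y = model(x)   # forward pass runs in mixed precision
print(y.dtype)     # torch.float16 inside the autocast region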