Skip to content

Commit

Permalink
Adds Backbones to FRCNN Take 2 (#475)
Browse files Browse the repository at this point in the history
* Adds frcnn

* update tests

* aaply suggestions from review

* refactor later

* leave any

* Apply suggestions from code review

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>

* remove unused imports

* black, isort, cleanup

* fix init issue

* noqa

* Fix isort

* black format and refacotor

* Fix isort again

* Fix isort

* decouple sutff

* yapf formatting

* Yapf reformat

* reformat

* converts warnings

* Formatted

* Fix bug

* Apply suggestions from code review

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>

* format yapf

* refactor

* chlog

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored Jan 18, 2021
1 parent 056f836 commit 523699a
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))
- Added gradient verification callback ([#465](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/465))
- Added Backbones to FRCNN ([#475](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/475))

### Changed

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def _random_bbox(self):

def __getitem__(self, idx: int):
img = torch.rand(self.img_shape)
boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)])
labels = torch.randint(self.num_classes, (self.num_boxes, ))
boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)], dtype=torch.float32)
labels = torch.randint(self.num_classes, (self.num_boxes, ), dtype=torch.long)
return img, {"boxes": boxes, "labels": labels}


Expand Down
5 changes: 1 addition & 4 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
__all__ = []

try:
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
except ModuleNotFoundError: # pragma: no-cover
pass # pragma: no-cover
else:
__all__.append('FasterRCNN')
3 changes: 3 additions & 0 deletions pl_bolts/models/detection/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pl_bolts.models.detection.components.torchvision_backbones import create_torchvision_backbone

__all__ = ["create_torchvision_backbone"]
28 changes: 28 additions & 0 deletions pl_bolts/models/detection/components/_supported_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
import torchvision

TORCHVISION_MODEL_ZOO = {
"vgg11": torchvision.models.vgg11,
"vgg13": torchvision.models.vgg13,
"vgg16": torchvision.models.vgg16,
"vgg19": torchvision.models.vgg19,
"resnet18": torchvision.models.resnet18,
"resnet34": torchvision.models.resnet34,
"resnet50": torchvision.models.resnet50,
"resnet101": torchvision.models.resnet101,
"resnet152": torchvision.models.resnet152,
"resnext50_32x4d": torchvision.models.resnext50_32x4d,
"resnext50_32x8d": torchvision.models.resnext101_32x8d,
"mnasnet0_5": torchvision.models.mnasnet0_5,
"mnasnet0_75": torchvision.models.mnasnet0_75,
"mnasnet1_0": torchvision.models.mnasnet1_0,
"mnasnet1_3": torchvision.models.mnasnet1_3,
"mobilenet_v2": torchvision.models.mobilenet_v2,
}

else: # pragma: no cover
warn_missing_pkg("torchvision")
TORCHVISION_MODEL_ZOO = {}
98 changes: 98 additions & 0 deletions pl_bolts/models/detection/components/torchvision_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Optional, Tuple

import torch.nn as nn

from pl_bolts.models.detection.components._supported_models import TORCHVISION_MODEL_ZOO
from pl_bolts.utils import _TORCHVISION_AVAILABLE # noqa: F401
from pl_bolts.utils.warnings import warn_missing_pkg # noqa: F401


def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module:
"""
Generic Backbone creater. It removes the last linear layer.
Args:
model: torch.nn model
out_channels: Number of out_channels in last layer.
"""
modules_total = list(model.children())
modules = modules_total[:-1]
ft_backbone = nn.Sequential(*modules)
ft_backbone.out_channels = out_channels
return ft_backbone


# Use this when you have Adaptive Pooling layer in End.
# When Model.features is not applicable.
def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = None) -> nn.Module:
"""
Creates backbone by removing linear after Adaptive Pooling layer.
Args:
model: torch.nn model with adaptive pooling layer
out_channels: Number of out_channels in last layer
"""
if out_channels is None:
modules_total = list(model.children())
out_channels = modules_total[-1].in_features
return _create_backbone_generic(model, out_channels=out_channels)


def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module:
"""
Creates backbone from feature sequential block.
Args:
model: torch.nn model with features as sequential block.
out_channels: Number of out_channels in last layer.
"""
ft_backbone = model.features
ft_backbone.out_channels = out_channels
return ft_backbone


def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
"""
Creates CNN backbone from Torchvision.
Args:
model_name: Name of the model. E.g. resnet18
pretrained: Pretrained weights dataset "imagenet", etc
"""

model_selected = TORCHVISION_MODEL_ZOO[model_name]
net = model_selected(pretrained=pretrained)

if model_name == "mobilenet_v2":
out_channels = 1280
ft_backbone = _create_backbone_features(net, 1280)
return ft_backbone, out_channels

elif model_name in ["vgg11", "vgg13", "vgg16", "vgg19"]:
out_channels = 512
ft_backbone = _create_backbone_features(net, out_channels)
return ft_backbone, out_channels

elif model_name in ["resnet18", "resnet34"]:
out_channels = 512
ft_backbone = _create_backbone_adaptive(net, out_channels)
return ft_backbone, out_channels

elif model_name in [
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
]:
out_channels = 2048
ft_backbone = _create_backbone_adaptive(net, out_channels)
return ft_backbone, out_channels

elif model_name in ["mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]:
out_channels = 1280
ft_backbone = _create_backbone_adaptive(net, out_channels)
return ft_backbone, out_channels

else:
raise ValueError(f"Unsupported model: '{model_name}'")
4 changes: 4 additions & 0 deletions pl_bolts/models/detection/faster_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from pl_bolts.models.detection.faster_rcnn.backbones import create_fasterrcnn_backbone
from pl_bolts.models.detection.faster_rcnn.faster_rcnn_module import FasterRCNN

__all__ = ["create_fasterrcnn_backbone", "FasterRCNN"]
41 changes: 41 additions & 0 deletions pl_bolts/models/detection/faster_rcnn/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any, Optional

import torch.nn as nn

from pl_bolts.models.detection.components import create_torchvision_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")


def create_fasterrcnn_backbone(
backbone: str,
fpn: bool = True,
pretrained: Optional[str] = None,
trainable_backbone_layers: int = 3,
**kwargs: Any
) -> nn.Module:
"""
Args:
backbone:
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152",
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
as resnets with fpn backbones.
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101",
"resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19",
fpn: If True then constructs fpn as well.
pretrained: If None creates imagenet weights backbone.
trainable_backbone_layers: number of trainable resnet layers starting from final block.
"""

if fpn:
# Creates a torchvision resnet model with fpn added.
backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
else:
# This does not create fpn backbone, it is supported for all models
backbone, _ = create_torchvision_backbone(backbone, pretrained)
return backbone
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from argparse import ArgumentParser
from typing import Any, Optional

import pytorch_lightning as pl
import torch

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
from torchvision.models.detection import faster_rcnn, fasterrcnn_resnet50_fpn
if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor
from torchvision.ops import box_iou
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")


def _evaluate_iou(target, pred):
Expand Down Expand Up @@ -42,43 +47,51 @@ class FasterRCNN(pl.LightningModule):
# PascalVOC
python faster_rcnn.py --gpus 1 --pretrained True
"""

def __init__(
self,
learning_rate: float = 0.0001,
num_classes: int = 91,
backbone: Optional[str] = None,
fpn: bool = True,
pretrained: bool = False,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
replace_head: bool = True,
**kwargs,
**kwargs: Any,
):
"""
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
backbone: Pretained backbone CNN architecture.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
super().__init__()

model = fasterrcnn_resnet50_fpn(
# num_classes=num_classes,
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
)
self.learning_rate = learning_rate
self.num_classes = num_classes
self.backbone = backbone
if backbone is None:
self.model = fasterrcnn_resnet50_fpn(
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
)

in_features = self.model.roi_heads.box_predictor.cls_score.in_features
self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.num_classes)

if replace_head:
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
else:
assert num_classes == 91, "replace_head must be true to change num_classes"

self.model = model
self.learning_rate = learning_rate
backbone_model = create_fasterrcnn_backbone(
self.backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
self.model = torchvision_FasterRCNN(backbone_model, num_classes=num_classes, **kwargs)

def forward(self, x):
self.model.eval()
Expand Down Expand Up @@ -119,10 +132,11 @@ def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--num_classes", type=int, default=91)
parser.add_argument("--backbone", type=str, default=None)
parser.add_argument("--fpn", type=bool, default=True)
parser.add_argument("--pretrained", type=bool, default=False)
parser.add_argument("--pretrained_backbone", type=bool, default=True)
parser.add_argument("--trainable_backbone_layers", type=int, default=3)
parser.add_argument("--replace_head", type=bool, default=True)
return parser


Expand All @@ -132,6 +146,8 @@ def run_cli():
pl.seed_everything(42)
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--batch_size", type=int, default=1)
parser = FasterRCNN.add_model_specific_args(parser)

args = parser.parse_args()
Expand Down
9 changes: 9 additions & 0 deletions tests/models/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ def test_fasterrcnn_train(tmpdir):

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl)


def test_fasterrcnn_bbone_train(tmpdir):
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=True)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, train_dl, valid_dl)

0 comments on commit 523699a

Please sign in to comment.