-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds Backbones to FRCNN Take 2 (#475)
* Adds frcnn * update tests * aaply suggestions from review * refactor later * leave any * Apply suggestions from code review Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> * remove unused imports * black, isort, cleanup * fix init issue * noqa * Fix isort * black format and refacotor * Fix isort again * Fix isort * decouple sutff * yapf formatting * Yapf reformat * reformat * converts warnings * Formatted * Fix bug * Apply suggestions from code review Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> * format yapf * refactor * chlog Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
056f836
commit 523699a
Showing
10 changed files
with
225 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,5 @@ | ||
__all__ = [] | ||
|
||
try: | ||
from pl_bolts.models.detection import components # noqa: F401 | ||
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 | ||
except ModuleNotFoundError: # pragma: no-cover | ||
pass # pragma: no-cover | ||
else: | ||
__all__.append('FasterRCNN') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pl_bolts.models.detection.components.torchvision_backbones import create_torchvision_backbone | ||
|
||
__all__ = ["create_torchvision_backbone"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
import torchvision | ||
|
||
TORCHVISION_MODEL_ZOO = { | ||
"vgg11": torchvision.models.vgg11, | ||
"vgg13": torchvision.models.vgg13, | ||
"vgg16": torchvision.models.vgg16, | ||
"vgg19": torchvision.models.vgg19, | ||
"resnet18": torchvision.models.resnet18, | ||
"resnet34": torchvision.models.resnet34, | ||
"resnet50": torchvision.models.resnet50, | ||
"resnet101": torchvision.models.resnet101, | ||
"resnet152": torchvision.models.resnet152, | ||
"resnext50_32x4d": torchvision.models.resnext50_32x4d, | ||
"resnext50_32x8d": torchvision.models.resnext101_32x8d, | ||
"mnasnet0_5": torchvision.models.mnasnet0_5, | ||
"mnasnet0_75": torchvision.models.mnasnet0_75, | ||
"mnasnet1_0": torchvision.models.mnasnet1_0, | ||
"mnasnet1_3": torchvision.models.mnasnet1_3, | ||
"mobilenet_v2": torchvision.models.mobilenet_v2, | ||
} | ||
|
||
else: # pragma: no cover | ||
warn_missing_pkg("torchvision") | ||
TORCHVISION_MODEL_ZOO = {} |
98 changes: 98 additions & 0 deletions
98
pl_bolts/models/detection/components/torchvision_backbones.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Optional, Tuple | ||
|
||
import torch.nn as nn | ||
|
||
from pl_bolts.models.detection.components._supported_models import TORCHVISION_MODEL_ZOO | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE # noqa: F401 | ||
from pl_bolts.utils.warnings import warn_missing_pkg # noqa: F401 | ||
|
||
|
||
def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: | ||
""" | ||
Generic Backbone creater. It removes the last linear layer. | ||
Args: | ||
model: torch.nn model | ||
out_channels: Number of out_channels in last layer. | ||
""" | ||
modules_total = list(model.children()) | ||
modules = modules_total[:-1] | ||
ft_backbone = nn.Sequential(*modules) | ||
ft_backbone.out_channels = out_channels | ||
return ft_backbone | ||
|
||
|
||
# Use this when you have Adaptive Pooling layer in End. | ||
# When Model.features is not applicable. | ||
def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = None) -> nn.Module: | ||
""" | ||
Creates backbone by removing linear after Adaptive Pooling layer. | ||
Args: | ||
model: torch.nn model with adaptive pooling layer | ||
out_channels: Number of out_channels in last layer | ||
""" | ||
if out_channels is None: | ||
modules_total = list(model.children()) | ||
out_channels = modules_total[-1].in_features | ||
return _create_backbone_generic(model, out_channels=out_channels) | ||
|
||
|
||
def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: | ||
""" | ||
Creates backbone from feature sequential block. | ||
Args: | ||
model: torch.nn model with features as sequential block. | ||
out_channels: Number of out_channels in last layer. | ||
""" | ||
ft_backbone = model.features | ||
ft_backbone.out_channels = out_channels | ||
return ft_backbone | ||
|
||
|
||
def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: | ||
""" | ||
Creates CNN backbone from Torchvision. | ||
Args: | ||
model_name: Name of the model. E.g. resnet18 | ||
pretrained: Pretrained weights dataset "imagenet", etc | ||
""" | ||
|
||
model_selected = TORCHVISION_MODEL_ZOO[model_name] | ||
net = model_selected(pretrained=pretrained) | ||
|
||
if model_name == "mobilenet_v2": | ||
out_channels = 1280 | ||
ft_backbone = _create_backbone_features(net, 1280) | ||
return ft_backbone, out_channels | ||
|
||
elif model_name in ["vgg11", "vgg13", "vgg16", "vgg19"]: | ||
out_channels = 512 | ||
ft_backbone = _create_backbone_features(net, out_channels) | ||
return ft_backbone, out_channels | ||
|
||
elif model_name in ["resnet18", "resnet34"]: | ||
out_channels = 512 | ||
ft_backbone = _create_backbone_adaptive(net, out_channels) | ||
return ft_backbone, out_channels | ||
|
||
elif model_name in [ | ||
"resnet50", | ||
"resnet101", | ||
"resnet152", | ||
"resnext50_32x4d", | ||
"resnext101_32x8d", | ||
]: | ||
out_channels = 2048 | ||
ft_backbone = _create_backbone_adaptive(net, out_channels) | ||
return ft_backbone, out_channels | ||
|
||
elif model_name in ["mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]: | ||
out_channels = 1280 | ||
ft_backbone = _create_backbone_adaptive(net, out_channels) | ||
return ft_backbone, out_channels | ||
|
||
else: | ||
raise ValueError(f"Unsupported model: '{model_name}'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from pl_bolts.models.detection.faster_rcnn.backbones import create_fasterrcnn_backbone | ||
from pl_bolts.models.detection.faster_rcnn.faster_rcnn_module import FasterRCNN | ||
|
||
__all__ = ["create_fasterrcnn_backbone", "FasterRCNN"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from typing import Any, Optional | ||
|
||
import torch.nn as nn | ||
|
||
from pl_bolts.models.detection.components import create_torchvision_backbone | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone | ||
else: # pragma: no cover | ||
warn_missing_pkg("torchvision") | ||
|
||
|
||
def create_fasterrcnn_backbone( | ||
backbone: str, | ||
fpn: bool = True, | ||
pretrained: Optional[str] = None, | ||
trainable_backbone_layers: int = 3, | ||
**kwargs: Any | ||
) -> nn.Module: | ||
""" | ||
Args: | ||
backbone: | ||
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", | ||
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", | ||
as resnets with fpn backbones. | ||
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", | ||
"resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", | ||
fpn: If True then constructs fpn as well. | ||
pretrained: If None creates imagenet weights backbone. | ||
trainable_backbone_layers: number of trainable resnet layers starting from final block. | ||
""" | ||
|
||
if fpn: | ||
# Creates a torchvision resnet model with fpn added. | ||
backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) | ||
else: | ||
# This does not create fpn backbone, it is supported for all models | ||
backbone, _ = create_torchvision_backbone(backbone, pretrained) | ||
return backbone |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters