Skip to content

Commit

Permalink
Post-paper Detection Optimizations (#5444)
Browse files Browse the repository at this point in the history
* Use frozen BN only if pre-trained.

* Add LSJ and ability to from scratch training.

* Fixing formatter

* Adding `--opt` and `--norm-weight-decay` support in Detection.

* Fix error message

* Make ScaleJitter proportional.

* Adding more norm layers in split_normalization_params.

* Add FixedSizeCrop

* Temporary fix for fill values on PIL

* Fix the bug on fill.

* Add RandomShortestSize.

* Skip resize when an augmentation method is used.

* multiscale in [480, 800]

* Add missing star

* Add new RetinaNet variant.

* Add tests.

* Update expected file for old retina

* Fixing tests

* Add FrozenBN to retinav2

* Fix network initialization issues

* Adding BN support in MaskRCNNHeads and FPN

* Adding support of FasterRCNNHeads

* Introduce norm_layers in backbone utils.

* Bigger RPN head + 2x rcnn v2 models.

* Adding gIoU support to retinanet

* Fix assert

* Add back nesterov momentum

* Rename and extend `FastRCNNConvFCHead` to support arbitrary FCs

* Fix linter
  • Loading branch information
datumbox authored Apr 5, 2022
1 parent 63576c9 commit 08cc9a7
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 54 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
35 changes: 35 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,14 @@ def _check_input_backprop(model, inputs):
"googlenet": lambda x: x.logits,
"inception_v3": lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn_v2": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn_v2": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
"ssdlite320_mobilenet_v3_large": lambda x: x[1],
"fcos_resnet50_fpn": lambda x: x[1],
Expand Down Expand Up @@ -227,6 +230,7 @@ def _check_input_backprop(model, inputs):
"fcn_resnet101",
"lraspp_mobilenet_v3_large",
"maskrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn_v2",
)

# The tests for the following quantized models are flaky possibly due to inconsistent
Expand All @@ -246,6 +250,13 @@ def _check_input_backprop(model, inputs):
"max_size": 224,
"input_shape": (3, 224, 224),
},
"retinanet_resnet50_fpn_v2": {
"num_classes": 20,
"score_thresh": 0.01,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"keypointrcnn_resnet50_fpn": {
"num_classes": 2,
"min_size": 224,
Expand All @@ -259,6 +270,12 @@ def _check_input_backprop(model, inputs):
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fasterrcnn_resnet50_fpn_v2": {
"num_classes": 20,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fcos_resnet50_fpn": {
"num_classes": 2,
"score_thresh": 0.05,
Expand All @@ -272,6 +289,12 @@ def _check_input_backprop(model, inputs):
"max_size": 224,
"input_shape": (3, 224, 224),
},
"maskrcnn_resnet50_fpn_v2": {
"num_classes": 10,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fasterrcnn_mobilenet_v3_large_fpn": {
"box_score_thresh": 0.02076,
},
Expand Down Expand Up @@ -311,6 +334,10 @@ def _check_input_backprop(model, inputs):
"max_trainable": 5,
"n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
},
"retinanet_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
},
"keypointrcnn_resnet50_fpn": {
"max_trainable": 5,
"n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
Expand All @@ -319,10 +346,18 @@ def _check_input_backprop(model, inputs):
"max_trainable": 5,
"n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
},
"fasterrcnn_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
},
"maskrcnn_resnet50_fpn": {
"max_trainable": 5,
"n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
},
"maskrcnn_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
},
"fasterrcnn_mobilenet_v3_large_fpn": {
"max_trainable": 6,
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
Expand Down
28 changes: 26 additions & 2 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
from collections import OrderedDict
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d
from torch.nn import functional as F
from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
Expand Down Expand Up @@ -507,3 +508,26 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return _fake_cast_onnx(min_kval)


def _box_loss(
type: str,
box_coder: BoxCoder,
anchors_per_image: Tensor,
matched_gt_boxes_per_image: Tensor,
bbox_regression_per_image: Tensor,
cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")

if type == "l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
elif type == "smooth_l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
else: # giou
bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
13 changes: 11 additions & 2 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class BackboneWithFPN(nn.Module):
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
out_channels (int): number of channels in the FPN.
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
Attributes:
out_channels (int): the number of channels in the FPN
"""
Expand All @@ -36,6 +37,7 @@ def __init__(
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()

Expand All @@ -47,6 +49,7 @@ def __init__(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=extra_blocks,
norm_layer=norm_layer,
)
self.out_channels = out_channels

Expand Down Expand Up @@ -115,6 +118,7 @@ def _resnet_fpn_extractor(
trainable_layers: int,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:

# select layers that wont be frozen
Expand All @@ -139,7 +143,9 @@ def _resnet_fpn_extractor(
in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
out_channels = 256
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
return BackboneWithFPN(
backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
)


def _validate_trainable_layers(
Expand Down Expand Up @@ -194,6 +200,7 @@ def _mobilenet_extractor(
trainable_layers: int,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand Down Expand Up @@ -222,7 +229,9 @@ def _mobilenet_extractor(
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
return BackboneWithFPN(
backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
)
else:
m = nn.Sequential(
backbone,
Expand Down
115 changes: 111 additions & 4 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
Expand All @@ -24,14 +24,22 @@
__all__ = [
"FasterRCNN",
"FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNN_ResNet50_FPN_V2_Weights",
"FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_resnet50_fpn_v2",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
return AnchorGenerator(anchor_sizes, aspect_ratios)


class FasterRCNN(GeneralizedRCNN):
"""
Implements Faster R-CNN.
Expand Down Expand Up @@ -216,9 +224,7 @@ def __init__(
out_channels = backbone.out_channels

if rpn_anchor_generator is None:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
rpn_anchor_generator = _default_anchorgen()
if rpn_head is None:
rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

Expand Down Expand Up @@ -298,6 +304,43 @@ def forward(self, x):
return x


class FastRCNNConvFCHead(nn.Sequential):
def __init__(
self,
input_size: Tuple[int, int, int],
conv_layers: List[int],
fc_layers: List[int],
norm_layer: Optional[Callable[..., nn.Module]] = None,
):
"""
Args:
input_size (Tuple[int, int, int]): the input size in CHW format.
conv_layers (list): feature dimensions of each Convolution layer
fc_layers (list): feature dimensions of each FCN layer
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
"""
in_channels, in_height, in_width = input_size

blocks = []
previous_channels = in_channels
for current_channels in conv_layers:
blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
previous_channels = current_channels
blocks.append(nn.Flatten())
previous_channels = previous_channels * in_height * in_width
for current_channels in fc_layers:
blocks.append(nn.Linear(previous_channels, current_channels))
blocks.append(nn.ReLU(inplace=True))
previous_channels = current_channels

super().__init__(*blocks)
for layer in self.modules():
if isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
if layer.bias is not None:
nn.init.zeros_(layer.bias)


class FastRCNNPredictor(nn.Module):
"""
Standard classification + bounding box regression layers
Expand Down Expand Up @@ -349,6 +392,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1


class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
pass


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
Expand Down Expand Up @@ -481,6 +528,66 @@ def fasterrcnn_resnet50_fpn(
return model


def fasterrcnn_resnet50_fpn_v2(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone.
Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
<https://arxiv.org/abs/2111.11429>`_.
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details.
Args:
weights (FasterRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)

if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

backbone = resnet50(weights=weights_backbone, progress=progress)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
rpn_anchor_generator = _default_anchorgen()
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
box_head = FastRCNNConvFCHead(
(backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
)
model = FasterRCNN(
backbone,
num_classes=num_classes,
rpn_anchor_generator=rpn_anchor_generator,
rpn_head=rpn_head,
box_head=box_head,
**kwargs,
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model


def _fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
Expand Down
Loading

0 comments on commit 08cc9a7

Please sign in to comment.