
Merge pull request pytorch#3 from o295/main
Fix Python lint and docstrings, and add typing annotations
xiaohu2015 authored Nov 19, 2021
2 parents 5d75049 + d4c08d3 commit edb2a1a
Showing 1 changed file with 51 additions and 17 deletions.
torchvision/models/detection/fcos.py
@@ -1,6 +1,7 @@
 import math
 import warnings
 from collections import OrderedDict
+from functools import partial
 from typing import Dict, List, Tuple, Optional
 
 import torch
@@ -26,6 +27,7 @@
 class FCOSHead(nn.Module):
     """
     A regression and classification head for use in FCOS.
+
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -117,6 +119,7 @@ def forward(self, x):
 class FCOSClassificationHead(nn.Module):
     """
     A classification head for use in FCOS.
+
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -131,7 +134,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro
         self.num_anchors = num_anchors
 
         if norm_layer is None:
-            norm_layer = lambda channels: nn.GroupNorm(32, channels)
+            norm_layer = partial(nn.GroupNorm, 32)
 
         conv = []
         for _ in range(num_convs):
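
Note: the lambda → partial change here (and again in FCOSRegressionHead below) is the lint fix: assigning a lambda to a name trips flake8's E731, and a lambda also cannot be pickled, while a functools.partial can. A minimal sketch of the equivalence (the 256-channel input is just an illustrative value):

    from functools import partial

    import torch
    from torch import nn

    # Both factories map `channels` to nn.GroupNorm(32, channels).
    make_norm = partial(nn.GroupNorm, 32)
    norm = make_norm(256)               # same module as nn.GroupNorm(32, 256)
    out = norm(torch.rand(2, 256, 7, 7))
    print(out.shape)                    # torch.Size([2, 256, 7, 7])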
@@ -149,8 +152,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro
         torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
         torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
 
-    def forward(self, x):
-        # type: (List[Tensor]) -> Tensor
+    def forward(self, x: List[Tensor]) -> Tensor:
         all_cls_logits = []
 
         for features in x:
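
Note: the same mechanical change recurs throughout the rest of the diff: `# type:` comments (the older TorchScript/mypy style) become inline PEP 484 annotations, which TorchScript accepts equally and which IDEs and linters can inspect. A sketch of the two spellings, with a hypothetical concat function standing in for the heads' forward:

    from typing import List

    import torch
    from torch import Tensor

    def concat_old(x):
        # type: (List[Tensor]) -> Tensor
        return torch.cat(x)          # signature lives in a type comment

    def concat_new(x: List[Tensor]) -> Tensor:
        return torch.cat(x)          # same signature, written inline

    print(concat_new([torch.ones(2), torch.zeros(3)]).shape)  # torch.Size([5])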
@@ -171,6 +173,7 @@ def forward(self, x):
 class FCOSRegressionHead(nn.Module):
     """
     A regression head for use in FCOS.
+
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -181,7 +184,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
         super().__init__()
 
         if norm_layer is None:
-            norm_layer = lambda channels: nn.GroupNorm(32, channels)
+            norm_layer = partial(nn.GroupNorm, 32)
 
         conv = []
         for _ in range(num_convs):
@@ -201,8 +204,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
         torch.nn.init.normal_(layer.weight, std=0.01)
         torch.nn.init.zeros_(layer.bias)
 
-    def forward(self, x):
-        # type: (List[Tensor]) -> Tensor
+    def forward(self, x: List[Tensor]) -> Tensor:
         all_bbox_regression = []
         all_bbox_ctrness = []
 
@@ -230,23 +232,29 @@ def forward(self, x):
 class FCOS(nn.Module):
     """
     Implements FCOS.
+
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
+
     The behavior of the model changes depending if it is in training or evaluation mode.
+
     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the class label for each ground-truth box
+
     The model returns a Dict[Tensor] during training, containing the classification and regression
     losses.
+
     During inference, the model requires only the input tensors, and returns the post-processed
     predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
     follows:
         - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the predicted labels for each image
         - scores (Tensor[N]): the scores for each prediction
+
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
             It should contain an out_channels attribute, which indicates the number of output
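
Note: the docstring above spells out the training contract; a minimal sketch of a conforming target for one image (coordinate values are arbitrary illustrative numbers):

    import torch

    # One image with two ground-truth objects; boxes are [x1, y1, x2, y2]
    # with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H, labels are int64 ids.
    targets = [
        {
            "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0],
                                   [50.0, 60.0, 150.0, 160.0]]),
            "labels": torch.tensor([1, 7], dtype=torch.int64),
        }
    ]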
@@ -272,7 +280,9 @@ class FCOS(nn.Module):
         nms_thresh (float): NMS threshold used for postprocessing the detections.
         detections_per_img (int): Number of best detections to keep after NMS.
         topk_candidates (int): Number of best detections to keep before NMS.
+
     Example:
+
         >>> import torch
         >>> import torchvision
         >>> from torchvision.models.detection import FCOS
@@ -364,15 +374,23 @@ def __init__(
         self._has_warned = False
 
     @torch.jit.unused
-    def eager_outputs(self, losses, detections):
-        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def eager_outputs(
+        self,
+        losses: Dict[str, Tensor],
+        detections: List[Dict[str, Tensor]]
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
         if self.training:
             return losses
 
         return detections
 
-    def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
-        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[int]) -> Dict[str, Tensor]
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        num_anchors_per_level: List[int],
+    ) -> Dict[str, Tensor]:
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image["boxes"].numel() == 0:
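
Note: eager_outputs is what makes the eager-mode return type mode-dependent: losses in training, detections in eval (under scripting the model returns both, as the corrected warning later in this diff states). A hedged usage sketch, assuming this branch where fcos_resnet50_fpn exists and no pretrained weights are requested:

    import torch
    import torchvision

    model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=False)
    images = [torch.rand(3, 300, 400)]
    targets = [{"boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
                "labels": torch.tensor([1], dtype=torch.int64)}]

    model.train()
    losses = model(images, targets)   # Dict[str, Tensor] of training losses
    model.eval()
    detections = model(images)        # List[Dict[str, Tensor]]: boxes, labels, scores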
@@ -417,8 +435,12 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
 
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder)
 
-    def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+    def postprocess_detections(
+        self,
+        head_outputs: Dict[str, List[Tensor]],
+        anchors: List[List[Tensor]],
+        image_shapes: List[Tuple[int, int]]
+    ) -> List[Dict[str, Tensor]]:
         class_logits = head_outputs["cls_logits"]
         box_regression = head_outputs["bbox_regression"]
         box_ctrness = head_outputs["bbox_ctrness"]
@@ -484,12 +506,16 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
 
         return detections
 
-    def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def forward(
+        self,
+        images: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
         """
         Args:
             images (list[Tensor]): images to be processed
             targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+
         Returns:
             result (list[BoxList] or dict[Tensor]): the output from the model.
                 During training, it returns a dict[Tensor] which contains the losses.
@@ -570,14 +596,15 @@ def forward(self, images, targets=None):
 
         if torch.jit.is_scripting():
             if not self._has_warned:
-                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
+                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                 self._has_warned = True
             return losses, detections
         return self.eager_outputs(losses, detections)
 
 
 model_urls = {
-    "fcos_resnet50_fpn_coco": "",
+    "fcos_resnet50_fpn_coco":
+        "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth",
 }

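Note: the empty placeholder URL is filled in with a checkpoint hosted on the contributor's o295/checkpoints releases. For reference, torchvision builders typically consume this table along these lines; a sketch of the usual pattern, not the exact code in this file:

    from torch.hub import load_state_dict_from_url

    def _load_pretrained(model, progress: bool = True):
        # Fetch and load the COCO weights registered in model_urls above.
        state_dict = load_state_dict_from_url(
            model_urls["fcos_resnet50_fpn_coco"], progress=progress)
        model.load_state_dict(state_dict)
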
@@ -587,16 +614,20 @@ def fcos_resnet50_fpn(
     """
     Constructs a FCOS model with a ResNet-50-FPN backbone.
+
     Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" <https://arxiv.org/abs/1904.01355>`_.
+
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
+
     The behavior of the model changes depending if it is in training or evaluation mode.
+
     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (``Int64Tensor[N]``): the class label for each ground-truth box
     The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
     losses.
     During inference, the model requires only the input tensors, and returns the post-processed
     predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
     follows, where ``N`` is the number of detections:
@@ -605,11 +636,14 @@ def fcos_resnet50_fpn(
         - labels (``Int64Tensor[N]``): the predicted labels for each detection
         - scores (``Tensor[N]``): the scores of each detection
+
     For more details on the output, you may refer to :ref:`instance_seg_output`.
+
-    Example::
+    Example:
+
         >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)
         >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
     Args:
         pretrained (bool): If True, returns a model pre-trained on COCO train2017
         progress (bool): If True, displays a progress bar of the download to stderr
