Improve docstrings, variable names
NielsRogge committed Oct 16, 2023
1 parent 8268145 commit 901d7f9
Showing 1 changed file with 75 additions and 49 deletions.
124 changes: 75 additions & 49 deletions src/transformers/models/mask_rcnn/modeling_maskrcnn.py
@@ -69,7 +69,7 @@ class MaskRCNNRPNOutput(ModelOutput):
Region Proposal Network (RPN) outputs.
Args:
losses (`torch.FloatTensor`):
losses (`dict[str, torch.Tensor]`):
Losses of the RPN head.
proposals (Tuple[`torch.FloatTensor`]):
Tuple of proposals, for each example in the batch. Each proposal is a `torch.FloatTensor` of shape
@@ -83,7 +83,7 @@ class MaskRCNNRPNOutput(ModelOutput):
4.
"""

losses: torch.FloatTensor = None
losses: dict[str, torch.FloatTensor] = None
proposals: Tuple[torch.FloatTensor] = None
logits: Optional[Tuple[torch.FloatTensor]] = None
pred_boxes: Optional[torch.FloatTensor] = None
@@ -595,7 +595,7 @@ def forward(self, pred, target, weight=None, avg_factor=None, reduction=None):
class AssignResult:
"""Stores assignments between predicted and truth boxes.
Attributes:
Args:
num_gts (`int`):
The number of truth boxes considered when computing this assignment
gt_indices (`torch.LongTensor`):
@@ -612,12 +612,10 @@ def __init__(self, num_gts, gt_indices, max_overlaps, labels=None):
self.gt_indices = gt_indices
self.max_overlaps = max_overlaps
self.labels = labels
# Interface for possible user-defined properties
self._extra_properties = {}

@property
def num_preds(self):
"""int: the number of predictions in this assignment"""
def num_preds(self) -> int:
"""The number of predictions in this assignment"""
return len(self.gt_indices)

def add_ground_truth(self, gt_labels):
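For illustration, a toy AssignResult for five anchors and two ground-truth boxes could be built as follows (the values are made up; in the mmdetection convention this class is ported from, a `gt_indices` entry of 0 marks a background prediction and a positive value k means the prediction is assigned to the k-th ground-truth box, counting from 1):

import torch

assign_result = AssignResult(
    num_gts=2,
    gt_indices=torch.LongTensor([0, 1, 0, 2, 1]),                # 0 = background, k > 0 = matched to gt box k
    max_overlaps=torch.FloatTensor([0.1, 0.8, 0.2, 0.9, 0.7]),   # IoU with the best-matching gt box
    labels=None,
)
assert assign_result.num_preds == 5  # one entry per predicted box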
@@ -1719,9 +1717,18 @@ class MaskRCNNRPN(nn.Module):
Networks](https://arxiv.org/abs/1506.01497).
    Source: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/anchor_head.py
Args:
config (`MaskRCNNConfig`):
Model configuration.
num_classes (`int`, *optional*, defaults to 1):
Number of classes.
reg_decoded_bbox (`bool`, *optional*, defaults to `False`):
Whether to apply the regression loss directly on decoded bounding boxes, converting both the predicted
boxes and regression targets to absolute coordinates format.
"""

def __init__(self, config, num_classes=1, reg_decoded_bbox=False):
def __init__(self, config, num_classes: int = 1, reg_decoded_bbox: bool = False):
super().__init__()

self.config = config
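As a rough sketch of what the `reg_decoded_bbox` flag documented above controls (the function and argument names below are illustrative, not the module's actual attributes): with the default `False`, the L1 loss is computed in delta space against encoded targets; with `True`, the predicted deltas are first decoded into absolute boxes and regressed against the ground-truth boxes directly.

def box_regression_loss(bbox_pred, anchors, gt_boxes, bbox_coder, l1_loss, reg_decoded_bbox):
    if reg_decoded_bbox:
        # decode predicted deltas into absolute (x1, y1, x2, y2) boxes and
        # compare them with the ground-truth boxes in absolute coordinates
        decoded_boxes = bbox_coder.decode(anchors, bbox_pred)
        return l1_loss(decoded_boxes, gt_boxes)
    # default: encode the ground-truth boxes as deltas w.r.t. the anchors
    # and compute the loss in (dx, dy, dw, dh) space
    bbox_targets = bbox_coder.encode(anchors, gt_boxes)
    return l1_loss(bbox_pred, bbox_targets)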
@@ -1752,7 +1759,7 @@ def __init__(self, config, num_classes=1, reg_decoded_bbox=False):
self.train_cfg = config.rpn_train_cfg
self.test_cfg = config.rpn_test_cfg

# IoU assigner
# IoU assigner (used to assign anchors to ground truth boxes)
self.assigner = MaskRCNNMaxIoUAssigner(
pos_iou_thr=config.rpn_assigner_pos_iou_thr,
neg_iou_thr=config.rpn_assigner_neg_iou_thr,
@@ -1781,19 +1788,20 @@ def __init__(self, config, num_classes=1, reg_decoded_bbox=False):
self.loss_bbox = WeightedL1Loss(loss_weight=loss_weight)

def forward_single(self, hidden_state):
"""Forward feature map of a single scale level."""
"""Forward a feature map of a single scale level."""
hidden_state = self.rpn_conv(hidden_state)
hidden_state = nn.functional.relu(hidden_state, inplace=True)
rpn_cls_score = self.rpn_cls(hidden_state)
rpn_bbox_pred = self.rpn_reg(hidden_state)
return rpn_cls_score, rpn_bbox_pred
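Since `forward_single` handles one scale level, `forward_features` below essentially maps it over the whole feature pyramid; a minimal sketch of that pattern (the helper name is made up):

def run_rpn_on_pyramid(rpn, hidden_states):
    # apply the per-level head to every feature map of the pyramid
    outputs = [rpn.forward_single(hidden_state) for hidden_state in hidden_states]
    cls_scores, bbox_preds = zip(*outputs)
    return cls_scores, bbox_preds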

def forward_features(self, hidden_states):
"""Forward features from the upstream network.
"""Update feature maps of the backbone network.
Args:
hidden_states (`Tuple[torch.FloatTensor]`):
Features from the upstream network, each being a 4D-tensor.
Feature maps from the backbone, each being a 4D-tensor.
Returns:
tuple: A tuple of classification scores and bbox prediction.
- cls_scores (`Tuple[torch.Tensor]`):
@@ -1887,7 +1895,8 @@ def _get_targets_single(
img_meta,
label_channels=1,
):
"""Compute regression and classification targets for anchors in a single image.
"""Compute regression and classification targets for anchors in a single image, by assigning ground truth
boxes and their class labels to anchors.
Args:
flat_anchors (`torch.Tensor`):
@@ -1922,17 +1931,19 @@
num_total_neg (`int`):
Number of negative samples in all images.
"""
# first, filter on anchors which are inside the allowed border of the image
inside_flags = anchor_inside_flags(
flat_anchors, valid_flags, img_meta["img_shape"][-2:], self.train_cfg["allowed_border"]
)
if not inside_flags.any():
return (None,) * 7
# assign ground truth and sample anchors
anchors = flat_anchors[inside_flags, :]

# next, assign ground truth boxes to anchors
assign_result = self.assigner.assign(
anchors, gt_bboxes, gt_bboxes_ignore, None if self.sampling else gt_labels
)
# next, sample anchors for training
sampling_result = self.sampler.sample(assign_result, anchors, gt_bboxes)

num_valid_anchors = anchors.shape[0]
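For context, the sampling result produced above is what later drives target construction; a small, hypothetical sketch of how it is typically consumed (the attributes `pos_bboxes` and `neg_bboxes` are used elsewhere in this file, the helper itself is illustrative):

def count_samples(sampling_results):
    # positives are anchors matched to a ground-truth box, negatives are background anchors;
    # summing over all images gives totals like the `num_total_neg` documented above
    num_pos = sum(res.pos_bboxes.shape[0] for res in sampling_results)
    num_neg = sum(res.neg_bboxes.shape[0] for res in sampling_results)
    return num_pos, num_neg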
@@ -2147,7 +2158,7 @@ def loss_single_scale_level(
return loss_cls, loss_bbox

def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, gt_bboxes_ignore=None):
"""Compute losses of the head.
"""Compute losses of the RPN head.
Args:
        cls_scores (`List[torch.Tensor]`):
@@ -2488,7 +2499,21 @@ def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:


class MaskRCNNSingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map."""
"""Extract Region-of-Interest (RoI) features from a single level feature map.
Source:
        https://github.com/open-mmlab/mmdetection/blob/f78af7785ada87f1ced75a2313746e4ba3149760/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py#L13.
Args:
roi_layer (`dict`):
Dictionary containing RoI settings, like the layer type and various arguments like `sampling_ratio`.
out_channels (`int`):
Output channels of RoI layers.
featmap_strides (`List[int]`):
Strides of the input feature maps.
finest_scale (`int`, *optional*, defaults to 56):
Scale threshold of mapping to level 0.
"""

def __init__(self, roi_layer, out_channels, featmap_strides, finest_scale=56):
super().__init__()
@@ -2499,12 +2524,13 @@ def __init__(self, roi_layer, out_channels, featmap_strides, finest_scale=56):
self.finest_scale = finest_scale

@property
def num_inputs(self):
"""int: Number of input feature maps."""
def num_inputs(self) -> int:
"""Number of input feature maps."""
return len(self.featmap_strides)

def build_roi_layers(self, layer_cfg, featmap_strides):
"""Build RoI operator to extract feature from each level feature map.
Args:
layer_cfg (`dict`):
Dictionary to construct and config RoI layer operation. Options are modules under `mmcv/ops` such as
@@ -2540,42 +2566,42 @@ def map_roi_levels(self, rois, num_levels):
`torch.Tensor`: Level index (0-based) of each RoI, shape (k, )
"""
scale = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls
target_levels = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
target_levels = target_levels.clamp(min=0, max=num_levels - 1).long()
return target_levels
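A small worked example of the mapping above, assuming the default `finest_scale=56` and a feature pyramid with `num_levels=4` (RoIs follow the `(batch_index, x1, y1, x2, y2)` layout implied by the indexing in `map_roi_levels`):

import torch

rois = torch.tensor([
    [0.0, 0.0, 0.0, 56.0, 56.0],      # scale  56 -> floor(log2(1))  = 0 -> level 0
    [0.0, 0.0, 0.0, 112.0, 112.0],    # scale 112 -> floor(log2(2))  = 1 -> level 1
    [0.0, 0.0, 0.0, 448.0, 448.0],    # scale 448 -> floor(log2(8))  = 3 -> level 3
    [0.0, 0.0, 0.0, 896.0, 896.0],    # scale 896 -> floor(log2(16)) = 4 -> clamped to level 3
])
scale = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
target_levels = torch.floor(torch.log2(scale / 56 + 1e-6)).clamp(min=0, max=3).long()
# target_levels -> tensor([0, 1, 3, 3])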

def forward(self, feats, rois, roi_scale_factor=None):
def forward(self, features, rois, roi_scale_factor=None):
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
num_levels = len(features)
expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
roi_feats = rois[:, :1].clone().detach()
roi_feats = roi_feats.expand(*expand_dims)
roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
roi_feats = roi_feats * 0
roi_features = rois[:, :1].clone().detach()
roi_features = roi_features.expand(*expand_dims)
roi_features = roi_features.reshape(-1, self.out_channels, *out_size)
roi_features = roi_features * 0

if num_levels == 1:
if len(rois) == 0:
return roi_feats
return self.roi_layers[0](feats[0], rois)
return roi_features
return self.roi_layers[0](features[0], rois)

target_lvls = self.map_roi_levels(rois, num_levels)
target_levels = self.map_roi_levels(rois, num_levels)

if roi_scale_factor is not None:
rois = self.roi_rescale(rois, roi_scale_factor)

for i in range(num_levels):
mask = target_lvls == i
mask = target_levels == i
# to keep all roi_align nodes exported to onnx and skip nonzero op
mask = mask.float().unsqueeze(-1)
# select target level rois and reset the rest rois to zero.
rois_i = rois.clone().detach()
rois_i *= mask
mask_exp = mask.expand(*expand_dims).reshape(roi_feats.shape)
roi_feats_t = self.roi_layers[i](feats[i], rois_i)
roi_feats_t *= mask_exp
roi_feats += roi_feats_t
mask_exp = mask.expand(*expand_dims).reshape(roi_features.shape)
roi_features_t = self.roi_layers[i](features[i], rois_i)
roi_features_t *= mask_exp
roi_features += roi_features_t

return roi_feats
return roi_features


class MaskRCNNShared2FCBBoxHead(nn.Module):
@@ -2925,7 +2951,7 @@ def _bbox_forward(self, feature_maps, rois):
bbox_features = self.bbox_roi_extractor(feature_maps[: self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(bbox_features)

bbox_results = {"cls_score": cls_score, "bbox_pred": bbox_pred, "bbox_feats": bbox_features}
bbox_results = {"cls_score": cls_score, "bbox_pred": bbox_pred, "bbox_features": bbox_features}

return bbox_results

@@ -2978,37 +3004,37 @@ def forward_test_bboxes(self, feature_maps, proposals, rcnn_test_cfg):

return rois, proposals, logits, pred_boxes

def _mask_forward(self, x, rois=None, pos_indices=None, bbox_feats=None):
def _mask_forward(self, x, rois=None, pos_indices=None, bbox_features=None):
"""Mask head forward function used in both training and testing.
Removed with_shared_head here.
"""
if not ((rois is not None) ^ (pos_indices is not None and bbox_feats is not None)):
raise ValueError("Either rois or (pos_indices and bbox_feats) should be specified")
if not ((rois is not None) ^ (pos_indices is not None and bbox_features is not None)):
raise ValueError("Either rois or (pos_indices and bbox_features) should be specified")
if rois is not None:
mask_feats = self.mask_roi_extractor(x[: self.mask_roi_extractor.num_inputs], rois)
mask_features = self.mask_roi_extractor(x[: self.mask_roi_extractor.num_inputs], rois)
else:
mask_feats = bbox_feats[pos_indices]
mask_features = bbox_features[pos_indices]

mask_pred = self.mask_head(mask_feats)
mask_pred = self.mask_head(mask_features)

mask_results = {"mask_pred": mask_pred, "mask_feats": mask_feats}
mask_results = {"mask_pred": mask_pred, "mask_features": mask_features}
return mask_results

def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks):
def _mask_forward_train(self, x, sampling_results, bbox_features, gt_masks):
"""Run forward function and calculate loss for mask head in training."""
if not self.share_roi_extractor:
pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
mask_results = self._mask_forward(x, pos_rois)
else:
pos_indices = []
device = bbox_feats.device
device = bbox_features.device
for res in sampling_results:
pos_indices.append(torch.ones(res.pos_bboxes.shape[0], device=device, dtype=torch.uint8))
pos_indices.append(torch.zeros(res.neg_bboxes.shape[0], device=device, dtype=torch.uint8))
pos_indices = torch.cat(pos_indices)

mask_results = self._mask_forward(x, pos_indices=pos_indices, bbox_feats=bbox_feats)
mask_results = self._mask_forward(x, pos_indices=pos_indices, bbox_features=bbox_features)

mask_targets = self.mask_head.get_targets(sampling_results, gt_masks, self.train_cfg)
pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
@@ -3082,7 +3108,7 @@ def forward_train(
proposals[i],
gt_bboxes[i],
gt_labels[i],
feats=[lvl_feat[i][None] for lvl_feat in feature_maps],
features=[lvl_feat[i][None] for lvl_feat in feature_maps],
)
sampling_results.append(sampling_result)

@@ -3095,7 +3121,7 @@ def forward_train(

if self.with_mask:
mask_results = self._mask_forward_train(
feature_maps, sampling_results, bbox_results["bbox_feats"], gt_masks
feature_maps, sampling_results, bbox_results["bbox_features"], gt_masks
)
losses.update(mask_results["loss_mask"])
