-
Notifications
You must be signed in to change notification settings - Fork 9.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support TTA of RetinaNet and GFL #3638
Changes from all commits
8436b37
9c91491
e4e0079
5fa2323
7694c60
4a6149f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,11 @@ | |
multiclass_nms, unmap) | ||
from ..builder import HEADS, build_loss | ||
from .base_dense_head import BaseDenseHead | ||
from .dense_test_mixins import BBoxTestMixin | ||
|
||
|
||
@HEADS.register_module() | ||
class AnchorHead(BaseDenseHead): | ||
class AnchorHead(BaseDenseHead, BBoxTestMixin): | ||
"""Anchor-based head (RPN, RetinaNet, SSD, etc.). | ||
|
||
Args: | ||
|
@@ -502,7 +503,8 @@ def get_bboxes(self, | |
bbox_preds, | ||
img_metas, | ||
cfg=None, | ||
rescale=False): | ||
rescale=False, | ||
with_nms=True): | ||
"""Transform network output for a batch into bbox predictions. | ||
|
||
Args: | ||
|
@@ -516,6 +518,8 @@ def get_bboxes(self, | |
if None, test_cfg would be used | ||
rescale (bool): If True, return boxes in original image space. | ||
Default: False. | ||
with_nms (bool): If True, do nms before return boxes. | ||
Default: True. | ||
|
||
Returns: | ||
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. | ||
|
@@ -569,9 +573,18 @@ def get_bboxes(self, | |
] | ||
img_shape = img_metas[img_id]['img_shape'] | ||
scale_factor = img_metas[img_id]['scale_factor'] | ||
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, | ||
mlvl_anchors, img_shape, | ||
scale_factor, cfg, rescale) | ||
if with_nms: | ||
# some heads don't support with_nms argument | ||
proposals = self._get_bboxes_single(cls_score_list, | ||
bbox_pred_list, | ||
mlvl_anchors, img_shape, | ||
scale_factor, cfg, rescale) | ||
else: | ||
proposals = self._get_bboxes_single(cls_score_list, | ||
bbox_pred_list, | ||
mlvl_anchors, img_shape, | ||
scale_factor, cfg, rescale, | ||
with_nms) | ||
Comment on lines +576 to +587
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we just pass … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
result_list.append(proposals) | ||
return result_list | ||
|
||
|
@@ -582,7 +595,8 @@ def _get_bboxes_single(self, | |
img_shape, | ||
scale_factor, | ||
cfg, | ||
rescale=False): | ||
rescale=False, | ||
with_nms=True): | ||
"""Transform outputs for a single batch item into bbox predictions. | ||
|
||
Args: | ||
|
@@ -599,6 +613,9 @@ def _get_bboxes_single(self, | |
cfg (mmcv.Config): Test / postprocessing configuration, | ||
if None, test_cfg would be used. | ||
rescale (bool): If True, return boxes in original image space. | ||
Default: False. | ||
with_nms (bool): If True, do nms before return boxes. | ||
Default: True. | ||
|
||
Returns: | ||
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns | ||
|
@@ -647,7 +664,29 @@ def _get_bboxes_single(self, | |
# BG cat_id: num_class | ||
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) | ||
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) | ||
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, | ||
cfg.score_thr, cfg.nms, | ||
cfg.max_per_img) | ||
return det_bboxes, det_labels | ||
|
||
if with_nms: | ||
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, | ||
cfg.score_thr, cfg.nms, | ||
cfg.max_per_img) | ||
return det_bboxes, det_labels | ||
else: | ||
return mlvl_bboxes, mlvl_scores | ||
|
||
def aug_test(self, feats, img_metas, rescale=False):
    """Test function with test time augmentation.

    Thin entry point that forwards every argument unchanged to
    ``self.aug_test_bboxes``.

    Args:
        feats (list[Tensor]): the outer list indicates test-time
            augmentations and inner Tensor should have a shape NxCxHxW,
            which contains features for all images in the batch.
        img_metas (list[list[dict]]): the outer list indicates test-time
            augs (multiscale, flip, etc.) and the inner list indicates
            images in a batch. each dict has image information.
        rescale (bool, optional): Whether to rescale the results.
            Defaults to False.

    Returns:
        list[ndarray]: bbox results of each class
    """
    return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from inspect import signature | ||
|
||
import torch | ||
|
||
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms | ||
|
||
|
||
class BBoxTestMixin(object):
    """Mixin class for test time augmentation of bboxes.

    The host class is expected to provide ``forward``, ``get_bboxes``,
    ``_get_bboxes_single``, ``test_cfg`` and ``num_classes``; all of them
    are read by :meth:`aug_test_bboxes`.
    """

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Merge augmented detection bboxes and scores.

        Each set of boxes is mapped back with ``bbox_mapping_back`` using
        the ``img_shape``, ``scale_factor``, ``flip`` and
        ``flip_direction`` entries of the first image's meta dict of the
        matching augmentation, then all sets are concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): shape (n, 4*#class)
            aug_scores (list[Tensor] or None): shape (n, #class)
            img_metas (list[list[dict]]): the outer list indicates
                test-time augmentations; only the first dict of each
                inner list is read.

        Returns:
            Tensor or tuple: ``bboxes`` alone when ``aug_scores`` is None,
            otherwise ``(bboxes, scores)``.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            # Only the first image of the batch is consulted.
            meta = img_info[0]
            bboxes = bbox_mapping_back(bboxes, meta['img_shape'],
                                       meta['scale_factor'], meta['flip'],
                                       meta['flip_direction'])
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class

        Raises:
            AssertionError: if ``get_bboxes`` or ``_get_bboxes_single`` of
                the host class does not accept a ``with_nms`` argument.
        """
        # Both hooks must accept ``with_nms`` so that per-augmentation NMS
        # can be skipped. Raised explicitly (not via ``assert``) so the
        # check is not stripped when Python runs with -O.
        supports_tta = (
            'with_nms' in signature(self.get_bboxes).parameters
            and 'with_nms' in signature(self._get_bboxes_single).parameters)
        if not supports_tta:
            raise AssertionError(
                f'{self.__class__.__name__}'
                ' does not support test-time augmentation')

        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            outs = self.forward(x)
            # The two trailing positional ``False`` values follow
            # ``img_metas`` and ``cfg``; presumably they land on
            # ``rescale`` and ``with_nms`` of the host's get_bboxes.
            bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
            det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0]
            aug_bboxes.append(det_bboxes)
            aug_scores.append(det_scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                self.test_cfg.score_thr,
                                                self.test_cfg.nms,
                                                self.test_cfg.max_per_img)

        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Scale a copy by the first augmentation's scale_factor so
            # that ``det_bboxes`` itself is left untouched.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
        return bbox_results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RPNHead has a mixin `RPNTestMixin`; can these two mixins somehow be merged to avoid confusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A simple merge of `dense_test_mixins.py` and `rpn_test_mixin.py` causes another confusion, because `dense_test_mixins.py` focuses on TTA and doesn't have `simple_test`.
This issue comes from the following inconsistencies:
- `roi_heads` have `simple_test`
- `RPNHead` and `GARPNHead` have `simple_test_rpn`
- `dense_heads` don't have `simple_test`
I think the inconsistencies should be addressed by a PR for refactoring, not by this PR for adding a feature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. A later refactoring will be proposed.