-
Notifications
You must be signed in to change notification settings - Fork 9.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support TTA of RetinaNet and GFL (#3638)
* Move RepPoints TTA to mixin class for reuse
* Support TTA of RetinaNet
* Support TTA of GFL
* Update to use BBoxTestMixin in dense_heads
* Update for v2.4.0 inference
- Loading branch information
Showing
7 changed files
with
195 additions
and
99 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from inspect import signature | ||
|
||
import torch | ||
|
||
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms | ||
|
||
|
||
class BBoxTestMixin(object):
    """Mixin that adds bbox test-time augmentation (TTA) to a dense head.

    The host class is expected to provide ``forward``, ``get_bboxes``,
    ``_get_bboxes_single``, ``test_cfg`` and ``num_classes``; all of them
    are read by :meth:`aug_test_bboxes`.
    """

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Map per-augmentation bboxes back and concatenate them.

        Args:
            aug_bboxes (list[Tensor]): one bbox tensor per augmentation
                (documented upstream as shape (n, 4*#class); the shape is
                not checked here).
            aug_scores (list[Tensor] or None): one score tensor per
                augmentation (documented upstream as shape (n, #class)),
                or None when only bboxes should be merged.
            img_metas (list[list[dict]]): one list of image meta dicts per
                augmentation. Only the first dict of each list is read,
                for the keys ``img_shape``, ``scale_factor``, ``flip`` and
                ``flip_direction``.

        Returns:
            Tensor or tuple[Tensor, Tensor]: the concatenated bboxes alone
            when ``aug_scores`` is None, otherwise ``(bboxes, scores)``.
        """
        mapped_back = []
        for boxes, metas in zip(aug_bboxes, img_metas):
            # Only one image per batch is handled, hence the first meta.
            meta = metas[0]
            mapped_back.append(
                bbox_mapping_back(boxes, meta['img_shape'],
                                  meta['scale_factor'], meta['flip'],
                                  meta['flip_direction']))
        merged = torch.cat(mapped_back, dim=0)
        if aug_scores is not None:
            return merged, torch.cat(aug_scores, dim=0)
        return merged

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Run bbox detection with test-time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and each inner Tensor should have a shape
                NxCxHxW, holding the features of all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates
                test-time augs (multiscale, flip, etc.) and the inner list
                indicates images in a batch. Each dict has image
                information.
            rescale (bool, optional): if True, the merged boxes are
                returned as they come out of NMS; if False, they are
                multiplied by the ``scale_factor`` of the first
                augmentation. Defaults to False.

        Returns:
            list[ndarray]: the output of ``bbox2result`` (bbox results of
            each class).

        Raises:
            AssertionError: if ``get_bboxes`` or ``_get_bboxes_single``
                has no ``with_nms`` parameter.
        """
        # Both methods must accept ``with_nms`` so that NMS can be deferred
        # until after merging. Both signatures are inspected eagerly
        # (list, not generator) before the check is evaluated.
        has_with_nms = [
            'with_nms' in signature(method).parameters
            for method in (self.get_bboxes, self._get_bboxes_single)
        ]
        assert all(has_with_nms), \
            f'{self.__class__.__name__}' \
            ' does not support test-time augmentation'

        per_aug_bboxes = []
        per_aug_scores = []
        for feat, metas in zip(feats, img_metas):
            # only one image in the batch
            head_outs = self.forward(feat)
            # Trailing positionals: cfg, then two False flags
            # (presumably rescale and with_nms - confirm against the
            # host class's ``get_bboxes`` signature).
            call_args = head_outs + (metas, self.test_cfg, False, False)
            bboxes, scores = self.get_bboxes(*call_args)[0]
            per_aug_bboxes.append(bboxes)
            per_aug_scores.append(scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            per_aug_bboxes, per_aug_scores, img_metas)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                self.test_cfg.score_thr,
                                                self.test_cfg.nms,
                                                self.test_cfg.max_per_img)

        final_bboxes = det_bboxes
        if not rescale:
            # Work on a copy so the NMS output is left untouched, then
            # scale the box coordinates by the first augmentation's factor.
            final_bboxes = det_bboxes.clone()
            final_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        return bbox2result(final_bboxes, det_labels, self.num_classes)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.