From 9c95543ea967f867a95164d51c21b24b0986f5aa Mon Sep 17 00:00:00 2001 From: Yosuke Shinya <42844407+shinya7y@users.noreply.github.com> Date: Thu, 24 Sep 2020 11:26:16 +0900 Subject: [PATCH] Support TTA of RetinaNet and GFL (#3638) * Move RepPoints TTA to mixin class for reuse * Support TTA of RetinaNet * Support TTA of GFL * Update to use BBoxTestMixin in dense_heads * Update for v2.4.0 inference --- mmdet/models/dense_heads/anchor_free_head.py | 21 ++++- mmdet/models/dense_heads/anchor_head.py | 59 ++++++++++--- mmdet/models/dense_heads/dense_test_mixins.py | 88 +++++++++++++++++++ mmdet/models/dense_heads/gfl_head.py | 16 ++-- mmdet/models/dense_heads/reppoints_head.py | 8 +- mmdet/models/detectors/reppoints_detector.py | 77 ---------------- mmdet/models/detectors/single_stage.py | 25 +++++- 7 files changed, 195 insertions(+), 99 deletions(-) create mode 100644 mmdet/models/dense_heads/dense_test_mixins.py diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py index 0dc5f3a2ee2..6c0f51e9502 100644 --- a/mmdet/models/dense_heads/anchor_free_head.py +++ b/mmdet/models/dense_heads/anchor_free_head.py @@ -8,10 +8,11 @@ from mmdet.core import multi_apply from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead +from .dense_test_mixins import BBoxTestMixin @HEADS.register_module() -class AnchorFreeHead(BaseDenseHead): +class AnchorFreeHead(BaseDenseHead, BBoxTestMixin): """Anchor-free head (FCOS, Fovea, RepPoints, etc.). Args: @@ -328,3 +329,21 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False): self._get_points_single(featmap_sizes[i], self.strides[i], dtype, device, flatten)) return mlvl_points + + def aug_test(self, feats, img_metas, rescale=False): + """Test function with test time augmentation. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + return self.aug_test_bboxes(feats, img_metas, rescale=rescale) diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index 37c5ef902d6..d13e51a1a94 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -8,10 +8,11 @@ images_to_levels, multi_apply, multiclass_nms, unmap) from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead +from .dense_test_mixins import BBoxTestMixin @HEADS.register_module() -class AnchorHead(BaseDenseHead): +class AnchorHead(BaseDenseHead, BBoxTestMixin): """Anchor-based head (RPN, RetinaNet, SSD, etc.). Args: @@ -502,7 +503,8 @@ def get_bboxes(self, bbox_preds, img_metas, cfg=None, - rescale=False): + rescale=False, + with_nms=True): """Transform network output for a batch into bbox predictions. Args: @@ -516,6 +518,8 @@ def get_bboxes(self, if None, test_cfg would be used rescale (bool): If True, return boxes in original image space. Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. Returns: list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. @@ -569,9 +573,18 @@ def get_bboxes(self, ] img_shape = img_metas[img_id]['img_shape'] scale_factor = img_metas[img_id]['scale_factor'] - proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, - mlvl_anchors, img_shape, - scale_factor, cfg, rescale) + if with_nms: + # some heads don't support with_nms argument + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale) + else: + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale, + with_nms) result_list.append(proposals) return result_list @@ -582,7 +595,8 @@ def _get_bboxes_single(self, img_shape, scale_factor, cfg, - rescale=False): + rescale=False, + with_nms=True): """Transform outputs for a single batch item into bbox predictions. Args: @@ -599,6 +613,9 @@ def _get_bboxes_single(self, cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. Returns: Tensor: Labeled boxes in shape (n, 5), where the first 4 columns @@ -647,7 +664,29 @@ def _get_bboxes_single(self, # BG cat_id: num_class padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels + + if with_nms: + det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + return det_bboxes, det_labels + else: + return mlvl_bboxes, mlvl_scores + + def aug_test(self, feats, img_metas, rescale=False): + """Test function with test time augmentation. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + return self.aug_test_bboxes(feats, img_metas, rescale=rescale) diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py new file mode 100644 index 00000000000..326621aded7 --- /dev/null +++ b/mmdet/models/dense_heads/dense_test_mixins.py @@ -0,0 +1,88 @@ +from inspect import signature + +import torch + +from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms + + +class BBoxTestMixin(object): + """Mixin class for test time augmentation of bboxes.""" + + def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas): + """Merge augmented detection bboxes and scores. + + Args: + aug_bboxes (list[Tensor]): shape (n, 4*#class) + aug_scores (list[Tensor] or None): shape (n, #class) + img_shapes (list[Tensor]): shape (3, ). + + Returns: + tuple: (bboxes, scores) + """ + recovered_bboxes = [] + for bboxes, img_info in zip(aug_bboxes, img_metas): + img_shape = img_info[0]['img_shape'] + scale_factor = img_info[0]['scale_factor'] + flip = img_info[0]['flip'] + flip_direction = img_info[0]['flip_direction'] + bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip, + flip_direction) + recovered_bboxes.append(bboxes) + bboxes = torch.cat(recovered_bboxes, dim=0) + if aug_scores is None: + return bboxes + else: + scores = torch.cat(aug_scores, dim=0) + return bboxes, scores + + def aug_test_bboxes(self, feats, img_metas, rescale=False): + """Test det bboxes with test time augmentation. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + # check with_nms argument + gb_sig = signature(self.get_bboxes) + gb_args = [p.name for p in gb_sig.parameters.values()] + gbs_sig = signature(self._get_bboxes_single) + gbs_args = [p.name for p in gbs_sig.parameters.values()] + assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \ + f'{self.__class__.__name__}' \ + ' does not support test-time augmentation' + + aug_bboxes = [] + aug_scores = [] + for x, img_meta in zip(feats, img_metas): + # only one image in the batch + outs = self.forward(x) + bbox_inputs = outs + (img_meta, self.test_cfg, False, False) + det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0] + aug_bboxes.append(det_bboxes) + aug_scores.append(det_scores) + + # after merging, bboxes will be rescaled to the original image size + merged_bboxes, merged_scores = self.merge_aug_bboxes( + aug_bboxes, aug_scores, img_metas) + det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, + self.test_cfg.score_thr, + self.test_cfg.nms, + self.test_cfg.max_per_img) + + if rescale: + _det_bboxes = det_bboxes + else: + _det_bboxes = det_bboxes.clone() + _det_bboxes[:, :4] *= det_bboxes.new_tensor( + img_metas[0][0]['scale_factor']) + bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes) + return bbox_results diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py index 55a9d0e1e03..8a9cd888593 100644 --- a/mmdet/models/dense_heads/gfl_head.py +++ b/mmdet/models/dense_heads/gfl_head.py @@ -382,7 +382,8 @@ def _get_bboxes_single(self, img_shape, scale_factor, cfg, - rescale=False): + rescale=False, + with_nms=True): """Transform outputs for a single batch item into labeled boxes. Args: @@ -401,6 +402,8 @@ def _get_bboxes_single(self, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. Returns: tuple(Tensor): @@ -450,10 +453,13 @@ def _get_bboxes_single(self, padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels + if with_nms: + det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + return det_bboxes, det_labels + else: + return mlvl_bboxes, mlvl_scores def get_targets(self, anchor_list, diff --git a/mmdet/models/dense_heads/reppoints_head.py b/mmdet/models/dense_heads/reppoints_head.py index 447b648c96c..16a0af907e2 100644 --- a/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdet/models/dense_heads/reppoints_head.py @@ -664,7 +664,7 @@ def get_bboxes(self, img_metas, cfg=None, rescale=False, - nms=True): + with_nms=True): assert len(cls_scores) == len(pts_preds_refine) bbox_preds_refine = [ self.points2bbox(pts_pred_refine) @@ -690,7 +690,7 @@ def get_bboxes(self, proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, mlvl_points, img_shape, scale_factor, cfg, rescale, - nms) + with_nms) result_list.append(proposals) return result_list @@ -702,7 +702,7 @@ def _get_bboxes_single(self, scale_factor, cfg, rescale=False, - nms=True): + with_nms=True): cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) mlvl_bboxes = [] @@ -749,7 +749,7 @@ def _get_bboxes_single(self, # BG cat_id: num_class padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - if nms: + if with_nms: det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img) diff --git a/mmdet/models/detectors/reppoints_detector.py b/mmdet/models/detectors/reppoints_detector.py index 35496b3c54e..a5f6be31e14 100644 --- a/mmdet/models/detectors/reppoints_detector.py +++ b/mmdet/models/detectors/reppoints_detector.py @@ -1,6 +1,3 @@ -import torch - -from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms from ..builder import DETECTORS from .single_stage import SingleStageDetector @@ -23,77 +20,3 @@ def __init__(self, super(RepPointsDetector, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained) - - def merge_aug_results(self, aug_bboxes, aug_scores, img_metas): - """Merge augmented detection bboxes and scores. - - Args: - aug_bboxes (list[Tensor]): shape (n, 4*#class) - aug_scores (list[Tensor] or None): shape (n, #class) - img_shapes (list[Tensor]): shape (3, ). - - Returns: - tuple: (bboxes, scores) - """ - recovered_bboxes = [] - for bboxes, img_info in zip(aug_bboxes, img_metas): - img_shape = img_info[0]['img_shape'] - scale_factor = img_info[0]['scale_factor'] - flip = img_info[0]['flip'] - flip_direction = img_info[0]['flip_direction'] - bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip, - flip_direction) - recovered_bboxes.append(bboxes) - bboxes = torch.cat(recovered_bboxes, dim=0) - if aug_scores is None: - return bboxes - else: - scores = torch.cat(aug_scores, dim=0) - return bboxes, scores - - def aug_test(self, imgs, img_metas, rescale=False): - """Test function with test time augmentation. - - Args: - imgs (list[Tensor]): the outer list indicates test-time - augmentations and inner Tensor should have a shape NxCxHxW, - which contains all images in the batch. - img_metas (list[list[dict]]): the outer list indicates test-time - augs (multiscale, flip, etc.) and the inner list indicates - images in a batch. each dict has image information. - rescale (bool, optional): Whether to rescale the results. - Defaults to False. - - Returns: - list[ndarray]: bbox results of each class - """ - # recompute feats to save memory - feats = self.extract_feats(imgs) - - aug_bboxes = [] - aug_scores = [] - for x, img_meta in zip(feats, img_metas): - # only one image in the batch - outs = self.bbox_head(x) - bbox_inputs = outs + (img_meta, self.test_cfg, False, False) - det_bboxes, det_scores = self.bbox_head.get_bboxes(*bbox_inputs)[0] - aug_bboxes.append(det_bboxes) - aug_scores.append(det_scores) - - # after merging, bboxes will be rescaled to the original image size - merged_bboxes, merged_scores = self.merge_aug_results( - aug_bboxes, aug_scores, img_metas) - det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, - self.test_cfg.score_thr, - self.test_cfg.nms, - self.test_cfg.max_per_img) - - if rescale: - _det_bboxes = det_bboxes - else: - _det_bboxes = det_bboxes.clone() - _det_bboxes[:, :4] *= det_bboxes.new_tensor( - img_metas[0][0]['scale_factor']) - bbox_results = bbox2result(_det_bboxes, det_labels, - self.bbox_head.num_classes) - return bbox_results diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index 3977bdff117..3932c9afcce 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -123,5 +123,26 @@ def simple_test(self, img, img_metas, rescale=False): return bbox_results def aug_test(self, imgs, img_metas, rescale=False): - """Test function with test time augmentation.""" - raise NotImplementedError + """Test function with test time augmentation. + + Args: + imgs (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + assert hasattr(self.bbox_head, 'aug_test'), \ + f'{self.bbox_head.__class__.__name__}' \ + ' does not support test-time augmentation' + + feats = self.extract_feats(imgs) + return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]