From 7578b6c13b54f7828064c7b5deb7e42a4318f731 Mon Sep 17 00:00:00 2001 From: Yosuke Shinya <42844407+shinya7y@users.noreply.github.com> Date: Sun, 25 Apr 2021 23:34:48 +0000 Subject: [PATCH] [Refactor] Add simple_test to dense heads --- mmdet/models/dense_heads/base_dense_head.py | 19 +++++++ mmdet/models/dense_heads/corner_head.py | 3 +- mmdet/models/dense_heads/dense_test_mixins.py | 32 ++++++++++++ .../models/dense_heads/embedding_rpn_head.py | 4 ++ mmdet/models/dense_heads/sabl_retina_head.py | 3 +- mmdet/models/dense_heads/transformer_head.py | 39 ++++++++++++++- mmdet/models/dense_heads/yolact_head.py | 47 +++++++++++++++++ mmdet/models/detectors/detr.py | 29 ----------- mmdet/models/detectors/single_stage.py | 29 ++++------- mmdet/models/detectors/yolact.py | 50 +++++-------------- .../test_dense_heads/test_dense_heads_attr.py | 43 ++++++++++++++++ 11 files changed, 210 insertions(+), 88 deletions(-) create mode 100644 tests/test_models/test_dense_heads/test_dense_heads_attr.py diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py index de11e4a2197..b4fab4c2a9b 100644 --- a/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdet/models/dense_heads/base_dense_head.py @@ -57,3 +57,22 @@ def forward_train(self, else: proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) return losses, proposal_list + + def simple_test(self, feats, img_metas, rescale=False, postprocess=True): + """Test function without test-time augmentation. + + Args: + feats (tuple[torch.Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + postprocess (bool, optional): Whether to perform post-processing + by bbox2result. Defaults to True. + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + return self.simple_test_bboxes( + feats, img_metas, rescale=rescale, postprocess=postprocess) diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py index 50cdb49a29f..485349948ab 100644 --- a/mmdet/models/dense_heads/corner_head.py +++ b/mmdet/models/dense_heads/corner_head.py @@ -11,6 +11,7 @@ from ..builder import HEADS, build_loss from ..utils import gaussian_radius, gen_gaussian_target from .base_dense_head import BaseDenseHead +from .dense_test_mixins import BBoxTestMixin class BiCornerPool(nn.Module): @@ -74,7 +75,7 @@ def forward(self, x): @HEADS.register_module() -class CornerHead(BaseDenseHead): +class CornerHead(BaseDenseHead, BBoxTestMixin): """Head of CornerNet: Detecting Objects as Paired Keypoints. Code is modified from the `official github repo diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py index dd81364dec9..527a301e82a 100644 --- a/mmdet/models/dense_heads/dense_test_mixins.py +++ b/mmdet/models/dense_heads/dense_test_mixins.py @@ -8,6 +8,38 @@ class BBoxTestMixin(object): """Mixin class for test time augmentation of bboxes.""" + def simple_test_bboxes(self, + feats, + img_metas, + rescale=False, + postprocess=True): + """Test det bboxes without test-time augmentation. + + Args: + feats (tuple[torch.Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + postprocess (bool, optional): Whether to perform post-processing + by bbox2result. Defaults to True. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + outs = self.forward(feats) + bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale) + if postprocess: + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.num_classes) + for det_bboxes, det_labels in bbox_list + ] + return bbox_results + else: + return bbox_list + def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas): """Merge augmented detection bboxes and scores. diff --git a/mmdet/models/dense_heads/embedding_rpn_head.py b/mmdet/models/dense_heads/embedding_rpn_head.py index 200ce8d20c5..fc8f308a09c 100644 --- a/mmdet/models/dense_heads/embedding_rpn_head.py +++ b/mmdet/models/dense_heads/embedding_rpn_head.py @@ -98,3 +98,7 @@ def forward_train(self, img, img_metas): def simple_test_rpn(self, img, img_metas): """Forward function in testing stage.""" return self._decode_init_proposals(img, img_metas) + + def simple_test(self, img, img_metas): + """Forward function in testing stage.""" + raise NotImplementedError diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py index 4211622cb8b..f101718b1b6 100644 --- a/mmdet/models/dense_heads/sabl_retina_head.py +++ b/mmdet/models/dense_heads/sabl_retina_head.py @@ -9,11 +9,12 @@ multi_apply, multiclass_nms, unmap) from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead +from .dense_test_mixins import BBoxTestMixin from .guided_anchor_head import GuidedAnchorHead @HEADS.register_module() -class SABLRetinaHead(BaseDenseHead): +class SABLRetinaHead(BaseDenseHead, BBoxTestMixin): """Side-Aware Boundary Localization (SABL) for RetinaNet. The anchor generation, assigning and sampling in SABLRetinaHead diff --git a/mmdet/models/dense_heads/transformer_head.py b/mmdet/models/dense_heads/transformer_head.py index 820fd069fcc..0e8f040eb2f 100644 --- a/mmdet/models/dense_heads/transformer_head.py +++ b/mmdet/models/dense_heads/transformer_head.py @@ -4,7 +4,7 @@ from mmcv.cnn import Conv2d, Linear, build_activation_layer from mmcv.runner import force_fp32 -from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, +from mmdet.core import (bbox2result, bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, build_assigner, build_sampler, multi_apply, reduce_mean) from mmdet.models.utils import (FFN, build_positional_encoding, @@ -652,3 +652,40 @@ def _get_bboxes_single(self, det_bboxes /= det_bboxes.new_tensor(scale_factor) det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1) return det_bboxes, det_labels + + def simple_test_bboxes(self, + feats, + img_metas, + rescale=False, + postprocess=True): + """Test det bboxes without test-time augmentation. + + Args: + feats (tuple[torch.Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + postprocess (bool, optional): Whether to perform post-processing + by bbox2result. Defaults to True. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + batch_size = len(img_metas) + assert batch_size == 1, 'Currently only batch_size 1 for inference ' \ + f'mode is supported. Found batch_size {batch_size}.' + + # forward of this head requires img_metas + outs = self.forward(feats, img_metas) + bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale) + if postprocess: + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.num_classes) + for det_bboxes, det_labels in bbox_list + ] + return bbox_results + else: + return bbox_list diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py index 10d311f94ee..d3a463ace9f 100644 --- a/mmdet/models/dense_heads/yolact_head.py +++ b/mmdet/models/dense_heads/yolact_head.py @@ -572,6 +572,12 @@ def get_targets(self, segm_pred, gt_masks, gt_labels): downsampled_masks[obj_idx]) return segm_targets + def simple_test(self, feats, img_metas, rescale=False, postprocess=True): + """Test function without test-time augmentation.""" + raise NotImplementedError( + 'simple_test of YOLACTSegmHead is not implemented ' + 'because this head is only evaluated during training') + @HEADS.register_module() class YOLACTProtonet(nn.Module): @@ -925,6 +931,47 @@ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True): x2 = torch.clamp(x2 + padding, max=img_size) return x1, x2 + def simple_test(self, + feats, + img_metas, + rescale=False, + det_bboxes=None, + det_labels=None, + det_coeffs=None): + assert det_bboxes + assert det_labels + assert det_coeffs + num_imgs = len(img_metas) + scale_factors = tuple(meta['scale_factor'] for meta in img_metas) + if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes): + segm_results = [[[] for _ in range(self.num_classes)] + for _ in range(num_imgs)] + else: + # if det_bboxes is rescaled to the original image size, we need to + # rescale it back to the testing scale to obtain RoIs. + if rescale and not isinstance(scale_factors[0], float): + scale_factors = [ + torch.from_numpy(scale_factor).to(det_bboxes[0].device) + for scale_factor in scale_factors + ] + _bboxes = [ + det_bboxes[i][:, :4] * + scale_factors[i] if rescale else det_bboxes[i][:, :4] + for i in range(len(det_bboxes)) + ] + mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas) + # apply mask post-processing to each image individually + segm_results = [] + for i in range(num_imgs): + if det_bboxes[i].shape[0] == 0: + segm_results.append([[] for _ in range(self.num_classes)]) + else: + segm_result = self.get_seg_masks(mask_preds[i], + det_labels[i], + img_metas[i], rescale) + segm_results.append(segm_result) + return segm_results + class InterpolateModule(nn.Module): """This is a module version of F.interpolate. diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py index 5ff82a280da..3720ee9ba6b 100644 --- a/mmdet/models/detectors/detr.py +++ b/mmdet/models/detectors/detr.py @@ -1,4 +1,3 @@ -from mmdet.core import bbox2result from ..builder import DETECTORS from .single_stage import SingleStageDetector @@ -16,31 +15,3 @@ def __init__(self, pretrained=None): super(DETR, self).__init__(backbone, None, bbox_head, train_cfg, test_cfg, pretrained) - - def simple_test(self, img, img_metas, rescale=False): - """Test function without test time augmentation. - - Args: - imgs (list[torch.Tensor]): List of multiple images - img_metas (list[dict]): List of image information. - rescale (bool, optional): Whether to rescale the results. - Defaults to False. - - Returns: - list[list[np.ndarray]]: BBox results of each image and classes. - The outer list corresponds to each image. The inner list - corresponds to each class. - """ - batch_size = len(img_metas) - assert batch_size == 1, 'Currently only batch_size 1 for inference ' \ - f'mode is supported. Found batch_size {batch_size}.' - x = self.extract_feat(img) - outs = self.bbox_head(x, img_metas) - bbox_list = self.bbox_head.get_bboxes( - *outs, img_metas, rescale=rescale) - - bbox_results = [ - bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) - for det_bboxes, det_labels in bbox_list - ] - return bbox_results diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index 5172bdbd945..fd012fb408b 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -from mmdet.core import bbox2result from ..builder import DETECTORS, build_backbone, build_head, build_neck from .base import BaseDetector @@ -95,38 +94,32 @@ def forward_train(self, gt_labels, gt_bboxes_ignore) return losses - def simple_test(self, img, img_metas, rescale=False): - """Test function without test time augmentation. + def simple_test(self, img, img_metas, rescale=False, postprocess=True): + """Test function without test-time augmentation. Args: - imgs (list[torch.Tensor]): List of multiple images + img (torch.Tensor): Images with shape (N, C, H, W). img_metas (list[dict]): List of image information. rescale (bool, optional): Whether to rescale the results. Defaults to False. + postprocess (bool, optional): Whether to perform post-processing + by bbox2result. Defaults to True. Returns: list[list[np.ndarray]]: BBox results of each image and classes. The outer list corresponds to each image. The inner list corresponds to each class. """ - x = self.extract_feat(img) - outs = self.bbox_head(x) - # get origin input shape to support onnx dynamic shape if torch.onnx.is_in_onnx_export(): - # get shape as tensor + # get origin input shape as tensor to support onnx dynamic shape img_shape = torch._shape_as_tensor(img)[2:] img_metas[0]['img_shape_for_onnx'] = img_shape - bbox_list = self.bbox_head.get_bboxes( - *outs, img_metas, rescale=rescale) - # skip post-processing when exporting to ONNX - if torch.onnx.is_in_onnx_export(): - return bbox_list + # skip post-processing when exporting to ONNX + postprocess = False - bbox_results = [ - bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) - for det_bboxes, det_labels in bbox_list - ] - return bbox_results + feat = self.extract_feat(img) + return self.bbox_head.simple_test( + feat, img_metas, rescale=rescale, postprocess=postprocess) def aug_test(self, imgs, img_metas, rescale=False): """Test function with test time augmentation. diff --git a/mmdet/models/detectors/yolact.py b/mmdet/models/detectors/yolact.py index f32fde0d3dc..f92e510a33d 100644 --- a/mmdet/models/detectors/yolact.py +++ b/mmdet/models/detectors/yolact.py @@ -96,49 +96,23 @@ def forward_train(self, return losses def simple_test(self, img, img_metas, rescale=False): - """Test function without test time augmentation.""" - x = self.extract_feat(img) - - cls_score, bbox_pred, coeff_pred = self.bbox_head(x) - - bbox_inputs = (cls_score, bbox_pred, - coeff_pred) + (img_metas, self.test_cfg, rescale) - det_bboxes, det_labels, det_coeffs = self.bbox_head.get_bboxes( - *bbox_inputs) + """Test function without test-time augmentation.""" + feat = self.extract_feat(img) + det_bboxes, det_labels, det_coeffs = self.bbox_head.simple_test( + feat, img_metas, rescale=rescale, postprocess=False) bbox_results = [ bbox2result(det_bbox, det_label, self.bbox_head.num_classes) for det_bbox, det_label in zip(det_bboxes, det_labels) ] - num_imgs = len(img_metas) - scale_factors = tuple(meta['scale_factor'] for meta in img_metas) - if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes): - segm_results = [[[] for _ in range(self.mask_head.num_classes)] - for _ in range(num_imgs)] - else: - # if det_bboxes is rescaled to the original image size, we need to - # rescale it back to the testing scale to obtain RoIs. - if rescale and not isinstance(scale_factors[0], float): - scale_factors = [ - torch.from_numpy(scale_factor).to(det_bboxes[0].device) - for scale_factor in scale_factors - ] - _bboxes = [ - det_bboxes[i][:, :4] * - scale_factors[i] if rescale else det_bboxes[i][:, :4] - for i in range(len(det_bboxes)) - ] - mask_preds = self.mask_head(x[0], det_coeffs, _bboxes, img_metas) - # apply mask post-processing to each image individually - segm_results = [] - for i in range(num_imgs): - if det_bboxes[i].shape[0] == 0: - segm_results.append( - [[] for _ in range(self.mask_head.num_classes)]) - else: - segm_result = self.mask_head.get_seg_masks( - mask_preds[i], det_labels[i], img_metas[i], rescale) - segm_results.append(segm_result) + segm_results = self.mask_head.simple_test( + feat, + img_metas, + rescale=rescale, + det_bboxes=det_bboxes, + det_labels=det_labels, + det_coeffs=det_coeffs) + return list(zip(bbox_results, segm_results)) def aug_test(self, imgs, img_metas, rescale=False): diff --git a/tests/test_models/test_dense_heads/test_dense_heads_attr.py b/tests/test_models/test_dense_heads/test_dense_heads_attr.py new file mode 100644 index 00000000000..f6be7f15272 --- /dev/null +++ b/tests/test_models/test_dense_heads/test_dense_heads_attr.py @@ -0,0 +1,43 @@ +import warnings + +from terminaltables import AsciiTable + +from mmdet.models import dense_heads +from mmdet.models.dense_heads import * # noqa: F401,F403 + + +def test_dense_heads_test_attr(): + """Tests inference methods such as simple_test and aug_test.""" + # make list of dense heads + exceptions = ['FeatureAdaption'] # module used in head + all_dense_heads = [m for m in dense_heads.__all__ if m not in exceptions] + + # search attributes + check_attributes = [ + 'simple_test', 'aug_test', 'simple_test_bboxes', 'simple_test_rpn', + 'aug_test_rpn' + ] + table_header = ['head name'] + check_attributes + table_data = [table_header] + not_found = {k: [] for k in check_attributes} + for target_head_name in all_dense_heads: + target_head = globals()[target_head_name] + target_head_attributes = dir(target_head) + check_results = [target_head_name] + for check_attribute in check_attributes: + found = check_attribute in target_head_attributes + check_results.append(found) + if not found: + not_found[check_attribute].append(target_head_name) + table_data.append(check_results) + table = AsciiTable(table_data) + print() + print(table.table) + + # NOTE: this test just checks attributes. + # simple_test of RPN heads will not work now. + assert len(not_found['simple_test']) == 0, \ + f'simple_test not found in {not_found["simple_test"]}' + if len(not_found['aug_test']) != 0: + warnings.warn(f'aug_test not found in {not_found["aug_test"]}. ' + 'Please implement it or raise NotImplementedError.')