Skip to content

Commit

Permalink
Support TTA of RetinaNet and GFL (#3638)
Browse files Browse the repository at this point in the history
* Move RepPoints TTA to mixin class for reuse

* Support TTA of RetinaNet

* Support TTA of GFL

* Update to use BBoxTestMixin in dense_heads

* Update for v2.4.0 inference
  • Loading branch information
shinya7y authored Sep 24, 2020
1 parent 868f7e5 commit 9c95543
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 99 deletions.
21 changes: 20 additions & 1 deletion mmdet/models/dense_heads/anchor_free_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead):
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-free head (FCOS, Fovea, RepPoints, etc.).
Args:
Expand Down Expand Up @@ -328,3 +329,21 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False):
self._get_points_single(featmap_sizes[i], self.strides[i],
dtype, device, flatten))
return mlvl_points

def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
59 changes: 49 additions & 10 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorHead(BaseDenseHead):
class AnchorHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
Args:
Expand Down Expand Up @@ -502,7 +503,8 @@ def get_bboxes(self,
bbox_preds,
img_metas,
cfg=None,
rescale=False):
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
Expand All @@ -516,6 +518,8 @@ def get_bboxes(self,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
Expand Down Expand Up @@ -569,9 +573,18 @@ def get_bboxes(self,
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
if with_nms:
# some heads don't support with_nms argument
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
else:
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale,
with_nms)
result_list.append(proposals)
return result_list

Expand All @@ -582,7 +595,8 @@ def _get_bboxes_single(self,
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.
Args:
Expand All @@ -599,6 +613,9 @@ def _get_bboxes_single(self,
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
Expand Down Expand Up @@ -647,7 +664,29 @@ def _get_bboxes_single(self,
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels

if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores

def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
88 changes: 88 additions & 0 deletions mmdet/models/dense_heads/dense_test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from inspect import signature

import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms


class BBoxTestMixin(object):
"""Mixin class for test time augmentation of bboxes."""

def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
Args:
aug_bboxes (list[Tensor]): shape (n, 4*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
img_shapes (list[Tensor]): shape (3, ).
Returns:
tuple: (bboxes, scores)
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
img_shape = img_info[0]['img_shape']
scale_factor = img_info[0]['scale_factor']
flip = img_info[0]['flip']
flip_direction = img_info[0]['flip_direction']
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores

def aug_test_bboxes(self, feats, img_metas, rescale=False):
"""Test det bboxes with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
# check with_nms argument
gb_sig = signature(self.get_bboxes)
gb_args = [p.name for p in gb_sig.parameters.values()]
gbs_sig = signature(self._get_bboxes_single)
gbs_args = [p.name for p in gbs_sig.parameters.values()]
assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
f'{self.__class__.__name__}' \
' does not support test-time augmentation'

aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.forward(x)
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)

if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
return bbox_results
16 changes: 11 additions & 5 deletions mmdet/models/dense_heads/gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def _get_bboxes_single(self,
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into labeled boxes.
Args:
Expand All @@ -401,6 +402,8 @@ def _get_bboxes_single(self,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
tuple(Tensor):
Expand Down Expand Up @@ -450,10 +453,13 @@ def _get_bboxes_single(self,
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores

def get_targets(self,
anchor_list,
Expand Down
8 changes: 4 additions & 4 deletions mmdet/models/dense_heads/reppoints_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def get_bboxes(self,
img_metas,
cfg=None,
rescale=False,
nms=True):
with_nms=True):
assert len(cls_scores) == len(pts_preds_refine)
bbox_preds_refine = [
self.points2bbox(pts_pred_refine)
Expand All @@ -690,7 +690,7 @@ def get_bboxes(self,
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale,
nms)
with_nms)
result_list.append(proposals)
return result_list

Expand All @@ -702,7 +702,7 @@ def _get_bboxes_single(self,
scale_factor,
cfg,
rescale=False,
nms=True):
with_nms=True):
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
Expand Down Expand Up @@ -749,7 +749,7 @@ def _get_bboxes_single(self,
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
if nms:
if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
Expand Down
77 changes: 0 additions & 77 deletions mmdet/models/detectors/reppoints_detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from ..builder import DETECTORS
from .single_stage import SingleStageDetector

Expand All @@ -23,77 +20,3 @@ def __init__(self,
super(RepPointsDetector,
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)

def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
Args:
aug_bboxes (list[Tensor]): shape (n, 4*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
img_shapes (list[Tensor]): shape (3, ).
Returns:
tuple: (bboxes, scores)
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
img_shape = img_info[0]['img_shape']
scale_factor = img_info[0]['scale_factor']
flip = img_info[0]['flip']
flip_direction = img_info[0]['flip_direction']
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores

def aug_test(self, imgs, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
imgs (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
# recompute feats to save memory
feats = self.extract_feats(imgs)

aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.bbox_head(x)
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
det_bboxes, det_scores = self.bbox_head.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_results(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)

if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_results
Loading

0 comments on commit 9c95543

Please sign in to comment.