Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TTA of RetinaNet and GFL #3638

Merged
merged 6 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mmdet/models/dense_heads/anchor_free_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from mmdet.core import force_fp32, multi_apply
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead):
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-free head (FCOS, Fovea, RepPoints, etc.).

Args:
Expand Down Expand Up @@ -327,3 +328,21 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False):
self._get_points_single(featmap_sizes[i], self.strides[i],
dtype, device, flatten))
return mlvl_points

def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.

Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.

Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
59 changes: 49 additions & 10 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorHead(BaseDenseHead):
class AnchorHead(BaseDenseHead, BBoxTestMixin):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RPNHead has a mixin RPNTestMixin, can these two mixins somehow be merged to avoid confusion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simple merge of dense_test_mixins.py and rpn_test_mixin.py causes another confusion, because dense_test_mixins.py focuses on TTA, and doesn't have simple_test.

This issue comes from the following inconsistencies.

  • roi_heads have simple_test
  • RPNHead and GARPNHead have simple_test_rpn
  • dense_heads don't have simple_test

I think the inconsistencies should be addressed by a PR for refactoring, not by this PR for adding a feature.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. A later refactoring will be proposed.

"""Anchor-based head (RPN, RetinaNet, SSD, etc.).

Args:
Expand Down Expand Up @@ -502,7 +503,8 @@ def get_bboxes(self,
bbox_preds,
img_metas,
cfg=None,
rescale=False):
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.

Args:
Expand All @@ -516,6 +518,8 @@ def get_bboxes(self,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.

Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
Expand Down Expand Up @@ -569,9 +573,18 @@ def get_bboxes(self,
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
if with_nms:
# some heads don't support with_nms argument
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
else:
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale,
with_nms)
Comment on lines +576 to +587
Copy link
Collaborator

@xvjiarui xvjiarui Sep 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we just pass with_nms into self._get_bboxes_single, will any detector raise an error?
I checked all heads inherited AnchorHead, it looks fine except for ATSSHead.
We may support it in ATSSHead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ATSSHead will not raise an error, since it has own get_bboxes.
But RPNHead will raise an error.
The code above also avoids breaking compatibility with mmdetection forks and other PRs.
Needless to say, the code should be cleaned after the with_nms argument becomes standard.

result_list.append(proposals)
return result_list

Expand All @@ -582,7 +595,8 @@ def _get_bboxes_single(self,
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.

Args:
Expand All @@ -599,6 +613,9 @@ def _get_bboxes_single(self,
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.

Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
Expand Down Expand Up @@ -647,7 +664,29 @@ def _get_bboxes_single(self,
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels

if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores

def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.

Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.

Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
88 changes: 88 additions & 0 deletions mmdet/models/dense_heads/dense_test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from inspect import signature

import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms


class BBoxTestMixin(object):
"""Mixin class for test time augmentation of bboxes."""

def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.

Args:
aug_bboxes (list[Tensor]): shape (n, 4*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
img_shapes (list[Tensor]): shape (3, ).

Returns:
tuple: (bboxes, scores)
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
img_shape = img_info[0]['img_shape']
scale_factor = img_info[0]['scale_factor']
flip = img_info[0]['flip']
flip_direction = img_info[0]['flip_direction']
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores

def aug_test_bboxes(self, feats, img_metas, rescale=False):
"""Test det bboxes with test time augmentation.

Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.

Returns:
list[ndarray]: bbox results of each class
"""
# check with_nms argument
gb_sig = signature(self.get_bboxes)
gb_args = [p.name for p in gb_sig.parameters.values()]
gbs_sig = signature(self._get_bboxes_single)
gbs_args = [p.name for p in gbs_sig.parameters.values()]
assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
f'{self.__class__.__name__}' \
' does not support test-time augmentation'

aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.forward(x)
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)

if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
return bbox_results
16 changes: 11 additions & 5 deletions mmdet/models/dense_heads/gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def _get_bboxes_single(self,
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into labeled boxes.

Args:
Expand All @@ -401,6 +402,8 @@ def _get_bboxes_single(self,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.

Returns:
tuple(Tensor):
Expand Down Expand Up @@ -450,10 +453,13 @@ def _get_bboxes_single(self,
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores

def get_targets(self,
anchor_list,
Expand Down
8 changes: 4 additions & 4 deletions mmdet/models/dense_heads/reppoints_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def get_bboxes(self,
img_metas,
cfg=None,
rescale=False,
nms=True):
with_nms=True):
assert len(cls_scores) == len(pts_preds_refine)
bbox_preds_refine = [
self.points2bbox(pts_pred_refine)
Expand All @@ -690,7 +690,7 @@ def get_bboxes(self,
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale,
nms)
with_nms)
result_list.append(proposals)
return result_list

Expand All @@ -702,7 +702,7 @@ def _get_bboxes_single(self,
scale_factor,
cfg,
rescale=False,
nms=True):
with_nms=True):
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
Expand Down Expand Up @@ -749,7 +749,7 @@ def _get_bboxes_single(self,
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
if nms:
if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
Expand Down
77 changes: 0 additions & 77 deletions mmdet/models/detectors/reppoints_detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from ..builder import DETECTORS
from .single_stage import SingleStageDetector

Expand All @@ -23,77 +20,3 @@ def __init__(self,
super(RepPointsDetector,
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)

def merge_aug_results(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.

Args:
aug_bboxes (list[Tensor]): shape (n, 4*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
img_shapes (list[Tensor]): shape (3, ).

Returns:
tuple: (bboxes, scores)
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
img_shape = img_info[0]['img_shape']
scale_factor = img_info[0]['scale_factor']
flip = img_info[0]['flip']
flip_direction = img_info[0]['flip_direction']
bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores

def aug_test(self, imgs, img_metas, rescale=False):
"""Test function with test time augmentation.

Args:
imgs (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.

Returns:
list[ndarray]: bbox results of each class
"""
# recompute feats to save memory
feats = self.extract_feats(imgs)

aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.bbox_head(x)
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
det_bboxes, det_scores = self.bbox_head.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_results(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)

if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= det_bboxes.new_tensor(
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_results
Loading