Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Add simple_test to dense heads #5061

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,22 @@ def forward_train(self,
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list

def simple_test(self, feats, img_metas, rescale=False, postprocess=True):
"""Test function without test-time augmentation.

Args:
feats (tuple[torch.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
postprocess (bool, optional): Whether to perform post-processing
by bbox2result. Defaults to True.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
return self.simple_test_bboxes(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you added the `postprocess` argument to get the raw box results in YOLACT, but I believe the original design (keeping `bbox2result` in the single-stage detector) would be a better design and can do the same thing. Are there other reasons?

In addition, I think keeping this operation in the SingleStageDetector instead of simple_test_bboxes would be more consistent with the two-stage models,
just like

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the argument postprocess to simplify code for onnx.

# skip post-processing when exporting to ONNX
postprocess = False

The original reason disappeared after #5205

Copy link
Collaborator

@jshilong jshilong May 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall design looks good to me, and I will help move `bbox2result` to `simple_test` in single_stage.py.
It will be merged soon. Thanks for your contribution.

feats, img_metas, rescale=rescale, postprocess=postprocess)
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/cascade_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,4 +782,5 @@ def simple_test_rpn(self, x, img_metas):

def aug_test_rpn(self, x, img_metas):
    """Augmented forward test function.

    This head provides no test-time augmentation path, so the call
    always fails with an explanatory message.

    Args:
        x: Unused.
        img_metas: Unused.

    Raises:
        NotImplementedError: On every call.
    """
    # A bare ``raise NotImplementedError`` used to precede this statement,
    # which made the descriptive message below unreachable.
    raise NotImplementedError(
        'CascadeRPNHead does not support test-time augmentation')
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/centernet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from ..utils.gaussian_target import (get_local_maximum, get_topk_from_heatmap,
transpose_and_gather_feat)
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class CenterNetHead(BaseDenseHead):
class CenterNetHead(BaseDenseHead, BBoxTestMixin):
"""Objects as Points Head. CenterHead use center_point to indicate object's
position. Paper link <https://arxiv.org/abs/1904.07850>

Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/corner_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_topk_from_heatmap,
transpose_and_gather_feat)
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


class BiCornerPool(BaseModule):
Expand Down Expand Up @@ -80,7 +81,7 @@ def forward(self, x):


@HEADS.register_module()
class CornerHead(BaseDenseHead):
class CornerHead(BaseDenseHead, BBoxTestMixin):
"""Head of CornerNet: Detecting Objects as Paired Keypoints.

Code is modified from the `official github repo
Expand Down
102 changes: 99 additions & 3 deletions mmdet/models/dense_heads/dense_test_mixins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,49 @@
import sys
from inspect import signature

import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from mmdet.core import (bbox2result, bbox_mapping_back, merge_aug_proposals,
multiclass_nms)

if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed

class BBoxTestMixin:
"""Mixin class for test time augmentation of bboxes."""

class BBoxTestMixin(object):
"""Mixin class for testing det bboxes."""

def simple_test_bboxes(self,
                       feats,
                       img_metas,
                       rescale=False,
                       postprocess=True):
    """Run the head on ``feats`` and decode boxes without augmentation.

    Args:
        feats (tuple[torch.Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
        img_metas (list[dict]): List of image information.
        rescale (bool, optional): Forwarded to ``self.get_bboxes`` as
            the ``rescale`` keyword. Defaults to False.
        postprocess (bool, optional): Whether to convert every
            ``(det_bboxes, det_labels)`` pair with ``bbox2result``.
            Defaults to True.

    Returns:
        list: With ``postprocess=True``, one ``bbox2result`` output per
            image; otherwise the unmodified return value of
            ``self.get_bboxes``.
    """
    head_outputs = self.forward(feats)
    detections = self.get_bboxes(*head_outputs, img_metas, rescale=rescale)
    # Hand back the raw detections when post-processing is switched off.
    if not postprocess:
        return detections
    results = []
    for det_bboxes, det_labels in detections:
        results.append(bbox2result(det_bboxes, det_labels, self.num_classes))
    return results

def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
Expand Down Expand Up @@ -98,3 +135,62 @@ def aug_test_bboxes(self, feats, img_metas, rescale=False):
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
jshilong marked this conversation as resolved.
Show resolved Hide resolved
return bbox_results

def simple_test_rpn(self, x, img_metas):
    """Generate proposals without test-time augmentation.

    Args:
        x (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.
        img_metas (list[dict]): Meta info of each image.

    Returns:
        list[Tensor]: Proposals of each image, as returned by
            ``self.get_bboxes``.
    """
    # Calling the head itself yields the positional inputs of get_bboxes.
    head_outputs = self(x)
    return self.get_bboxes(*head_outputs, img_metas)

def aug_test_rpn(self, feats, img_metas):
    """Test with augmentation for RPN.

    Args:
        feats: Iterated in lockstep with ``img_metas``; each element is
            passed to ``simple_test_rpn`` as the features of one
            augmented view.
        img_metas: Indexed as ``img_metas[aug][img]``; the length of
            ``img_metas[0]`` is taken as the number of images.

    Returns:
        list: One ``merge_aug_proposals`` result per image.
    """
    num_imgs = len(img_metas[0])
    per_img_proposals = [[] for _ in range(num_imgs)]
    for aug_feats, aug_metas in zip(feats, img_metas):
        aug_proposals = self.simple_test_rpn(aug_feats, aug_metas)
        for img_idx, proposals in enumerate(aug_proposals):
            per_img_proposals[img_idx].append(proposals)
    # Transpose the metas from [aug][img] to [img][aug] so they line up
    # with ``per_img_proposals``.
    per_img_metas = [[aug_metas[img_idx] for aug_metas in img_metas]
                     for img_idx in range(num_imgs)]
    # after merging, proposals will be rescaled to the original image size
    return [
        merge_aug_proposals(proposals, metas, self.test_cfg)
        for proposals, metas in zip(per_img_proposals, per_img_metas)
    ]

if sys.version_info >= (3, 7):

async def async_simple_test_rpn(self, x, img_metas):
    """Asynchronous counterpart of ``simple_test_rpn``.

    Args:
        x: Features passed to the head's forward call.
        img_metas: Meta info forwarded to ``self.get_bboxes``.

    Returns:
        The value returned by ``self.get_bboxes`` for the head outputs.
    """
    # Read the interval with ``get`` instead of ``pop``: popping removed
    # the key from the shared ``test_cfg``, so only the first call used a
    # configured value and every later call fell back to the default.
    sleep_interval = self.test_cfg.get('async_sleep_interval', 0.025)
    async with completed(
            __name__, 'rpn_head_forward',
            sleep_interval=sleep_interval):
        rpn_outs = self(x)

    proposal_list = self.get_bboxes(*rpn_outs, img_metas)
    return proposal_list
39 changes: 38 additions & 1 deletion mmdet/models/dense_heads/detr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.runner import force_fp32

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
from mmdet.core import (bbox2result, bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
build_assigner, build_sampler, multi_apply,
reduce_mean)
from mmdet.models.utils import build_transformer
Expand Down Expand Up @@ -680,3 +680,40 @@ def _get_bboxes_single(self,
det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)

return det_bboxes, det_labels

def simple_test_bboxes(self,
                       feats,
                       img_metas,
                       rescale=False,
                       postprocess=True):
    """Decode boxes for a single image without test-time augmentation.

    Args:
        feats (tuple[torch.Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
        img_metas (list[dict]): List of image information; must hold
            exactly one entry.
        rescale (bool, optional): Forwarded to ``self.get_bboxes`` as
            the ``rescale`` keyword. Defaults to False.
        postprocess (bool, optional): Whether to convert every
            ``(det_bboxes, det_labels)`` pair with ``bbox2result``.
            Defaults to True.

    Returns:
        list: With ``postprocess=True``, one ``bbox2result`` output per
            image; otherwise the unmodified return value of
            ``self.get_bboxes``.
    """
    batch_size = len(img_metas)
    assert batch_size == 1, (
        'Currently only batch_size 1 for inference '
        f'mode is supported. Found batch_size {batch_size}.')

    # Unlike the generic mixin, this head's forward also takes img_metas.
    head_outputs = self.forward(feats, img_metas)
    detections = self.get_bboxes(*head_outputs, img_metas, rescale=rescale)
    if not postprocess:
        return detections
    results = []
    for det_bboxes, det_labels in detections:
        results.append(bbox2result(det_bboxes, det_labels, self.num_classes))
    return results
8 changes: 8 additions & 0 deletions mmdet/models/dense_heads/embedding_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,11 @@ def forward_train(self, img, img_metas):
def simple_test_rpn(self, img, img_metas):
    """Return the result of ``_decode_init_proposals`` for the inputs."""
    proposals = self._decode_init_proposals(img, img_metas)
    return proposals

def simple_test(self, img, img_metas):
    """Unsupported entry point; this head only exposes ``simple_test_rpn``.

    Args:
        img: Unused.
        img_metas: Unused.

    Raises:
        NotImplementedError: On every call.
    """
    # Carry a message, like the sibling ``aug_test_rpn``, so the failure
    # points callers at the supported method.
    raise NotImplementedError(
        'EmbeddingRPNHead does not support simple_test, '
        'please use simple_test_rpn instead')

def aug_test_rpn(self, feats, img_metas):
    """Reject test-time augmentation, which this head does not provide.

    Raises:
        NotImplementedError: On every call.
    """
    message = 'EmbeddingRPNHead does not support test-time augmentation'
    raise NotImplementedError(message)
3 changes: 1 addition & 2 deletions mmdet/models/dense_heads/ga_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

from ..builder import HEADS
from .guided_anchor_head import GuidedAnchorHead
from .rpn_test_mixin import RPNTestMixin


@HEADS.register_module()
class GARPNHead(RPNTestMixin, GuidedAnchorHead):
class GARPNHead(GuidedAnchorHead):
"""Guided-Anchor-based RPN head."""

def __init__(self,
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/paa_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ def _get_bboxes(self,
cls_scores. Besides, score voting is used when `` score_voting``
is set to True.
"""
assert with_nms, 'PAA only supports "with_nms=True" now'
assert with_nms, 'PAA only supports "with_nms=True" now and it is ' \
jshilong marked this conversation as resolved.
Show resolved Hide resolved
'mean PAAHead does not support test-time augmentation'
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
batch_size = cls_scores[0].shape[0]

Expand Down
3 changes: 1 addition & 2 deletions mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@

from ..builder import HEADS
from .anchor_head import AnchorHead
from .rpn_test_mixin import RPNTestMixin


@HEADS.register_module()
class RPNHead(RPNTestMixin, AnchorHead):
class RPNHead(AnchorHead):
"""RPN head.

Args:
Expand Down
59 changes: 0 additions & 59 deletions mmdet/models/dense_heads/rpn_test_mixin.py

This file was deleted.

3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/sabl_retina_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module()
class SABLRetinaHead(BaseDenseHead):
class SABLRetinaHead(BaseDenseHead, BBoxTestMixin):
"""Side-Aware Boundary Localization (SABL) for RetinaNet.

The anchor generation, assigning and sampling in SABLRetinaHead
Expand Down
Loading