[Refactor] Add simple_test to dense heads #5061

Closed · wants to merge 16 commits
19 changes: 19 additions & 0 deletions mmdet/models/dense_heads/base_dense_head.py
@@ -57,3 +57,22 @@ def forward_train(self,
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list

def simple_test(self, feats, img_metas, rescale=False, postprocess=True):
"""Test function without test-time augmentation.

Args:
feats (tuple[torch.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
postprocess (bool, optional): Whether to perform post-processing
by bbox2result. Defaults to True.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
return self.simple_test_bboxes(
    feats, img_metas, rescale=rescale, postprocess=postprocess)
Collaborator:
It seems you added the postprocess argument to get raw box results for YOLACT, but I believe the original design, which keeps bbox2result in the single-stage detector, is better and can do the same thing. Are there other reasons?

In addition, I think keeping this operation in SingleStageDetector instead of simple_test_bboxes would be more consistent with the two-stage models, just like:

Contributor Author:
I added the postprocess argument to simplify the code for ONNX export.

# skip post-processing when exporting to ONNX
postprocess = False

The original reason no longer applies after #5205.
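
For context, a rough sketch (hypothetical, not part of this diff) of how that flag was meant to be used on the detector's export path, before #5205 removed the need for it:

import torch

def simple_test(self, img, img_metas, rescale=False):
    # sketch of SingleStageDetector.simple_test; names follow the APIs in this PR
    feat = self.extract_feat(img)
    # bbox2result converts tensors to numpy, so it is skipped while tracing for ONNX
    postprocess = not torch.onnx.is_in_onnx_export()
    return self.bbox_head.simple_test(
        feat, img_metas, rescale=rescale, postprocess=postprocess)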

Collaborator @jshilong, May 31, 2021:
The overall design looks good to me, and I will help move bbox2result to simple_test of single_stage.py.
It will be merged soon. Thanks for your contribution.

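A rough sketch (not part of this diff; assuming the APIs introduced above) of what moving bbox2result into SingleStageDetector.simple_test could look like:

from mmdet.core import bbox2result

def simple_test(self, img, img_metas, rescale=False):
    # sketch of SingleStageDetector.simple_test after the suggested move
    feat = self.extract_feat(img)
    # the dense head returns raw (det_bboxes, det_labels) pairs, one per image
    results_list = self.bbox_head.simple_test_bboxes(
        feat, img_metas, rescale=rescale, postprocess=False)
    # convert each pair into the per-class list[np.ndarray] format used for evaluation
    return [
        bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
        for det_bboxes, det_labels in results_list
    ]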
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/corner_head.py
@@ -12,6 +12,7 @@
from ..builder import HEADS, build_loss
from ..utils import gaussian_radius, gen_gaussian_target
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


class BiCornerPool(BaseModule):
@@ -78,7 +79,7 @@ def forward(self, x):


@HEADS.register_module()
class CornerHead(BaseDenseHead):
class CornerHead(BaseDenseHead, BBoxTestMixin):
"""Head of CornerNet: Detecting Objects as Paired Keypoints.

Code is modified from the `official github repo
90 changes: 88 additions & 2 deletions mmdet/models/dense_heads/dense_test_mixins.py
@@ -1,12 +1,49 @@
import sys
from inspect import signature

import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
from mmdet.core import (bbox2result, bbox_mapping_back, merge_aug_proposals,
multiclass_nms)

if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed


class BBoxTestMixin(object):
"""Mixin class for test time augmentation of bboxes."""
"""Mixin class for testing det bboxes."""

def simple_test_bboxes(self,
feats,
img_metas,
rescale=False,
postprocess=True):
"""Test det bboxes without test-time augmentation.

Args:
feats (tuple[torch.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
postprocess (bool, optional): Whether to perform post-processing
by bbox2result. Defaults to True.

Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
outs = self.forward(feats)
bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
if postprocess:
bbox_results = [
bbox2result(det_bboxes, det_labels, self.num_classes)
for det_bboxes, det_labels in bbox_list
]
return bbox_results
else:
return bbox_list

def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
"""Merge augmented detection bboxes and scores.
@@ -98,3 +135,52 @@ def aug_test_bboxes(self, feats, img_metas, rescale=False):
img_metas[0][0]['scale_factor'])
bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
return bbox_results

if sys.version_info >= (3, 7):

async def async_simple_test_rpn(self, x, img_metas):
sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025)
async with completed(
__name__, 'rpn_head_forward',
sleep_interval=sleep_interval):
rpn_outs = self(x)

proposal_list = self.get_bboxes(*rpn_outs, img_metas)
return proposal_list

def simple_test_rpn(self, x, img_metas):
"""Test without augmentation.

Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
img_metas (list[dict]): Meta info of each image.

Returns:
list[Tensor]: Proposals of each image.
"""
rpn_outs = self(x)
proposal_list = self.get_bboxes(*rpn_outs, img_metas)
return proposal_list

def aug_test_rpn(self, feats, img_metas):
samples_per_gpu = len(img_metas[0])
aug_proposals = [[] for _ in range(samples_per_gpu)]
for x, img_meta in zip(feats, img_metas):
proposal_list = self.simple_test_rpn(x, img_meta)
for i, proposals in enumerate(proposal_list):
aug_proposals[i].append(proposals)
# reorganize the order of 'img_metas' to match the dimensions
# of 'aug_proposals'
aug_img_metas = []
for i in range(samples_per_gpu):
aug_img_meta = []
for j in range(len(img_metas)):
aug_img_meta.append(img_metas[j][i])
aug_img_metas.append(aug_img_meta)
# after merging, proposals will be rescaled to the original image size
merged_proposals = [
merge_aug_proposals(proposals, aug_img_meta, self.test_cfg)
for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
]
return merged_proposals
39 changes: 38 additions & 1 deletion mmdet/models/dense_heads/detr_head.py
@@ -5,7 +5,7 @@
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.runner import force_fp32

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
from mmdet.core import (bbox2result, bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
build_assigner, build_sampler, multi_apply,
reduce_mean)
from mmdet.models.utils import build_transformer
@@ -680,3 +680,40 @@ def _get_bboxes_single(self,
det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)

return det_bboxes, det_labels

def simple_test_bboxes(self,
feats,
img_metas,
rescale=False,
postprocess=True):
"""Test det bboxes without test-time augmentation.

Args:
feats (tuple[torch.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
postprocess (bool, optional): Whether to perform post-processing
by bbox2result. Defaults to True.

Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
batch_size = len(img_metas)
assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
f'mode is supported. Found batch_size {batch_size}.'

# forward of this head requires img_metas
outs = self.forward(feats, img_metas)
bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
if postprocess:
bbox_results = [
bbox2result(det_bboxes, det_labels, self.num_classes)
for det_bboxes, det_labels in bbox_list
]
return bbox_results
else:
return bbox_list
4 changes: 4 additions & 0 deletions mmdet/models/dense_heads/embedding_rpn_head.py
@@ -105,3 +105,7 @@ def forward_train(self, img, img_metas):
def simple_test_rpn(self, img, img_metas):
"""Forward function in testing stage."""
return self._decode_init_proposals(img, img_metas)

def simple_test(self, img, img_metas):
"""Forward function in testing stage."""
raise NotImplementedError
3 changes: 1 addition & 2 deletions mmdet/models/dense_heads/ga_rpn_head.py
@@ -9,11 +9,10 @@

from ..builder import HEADS
from .guided_anchor_head import GuidedAnchorHead
from .rpn_test_mixin import RPNTestMixin


@HEADS.register_module()
class GARPNHead(RPNTestMixin, GuidedAnchorHead):
class GARPNHead(GuidedAnchorHead):
"""Guided-Anchor-based RPN head."""

def __init__(self,
3 changes: 1 addition & 2 deletions mmdet/models/dense_heads/rpn_head.py
@@ -9,11 +9,10 @@

from ..builder import HEADS
from .anchor_head import AnchorHead
from .rpn_test_mixin import RPNTestMixin


@HEADS.register_module()
class RPNHead(RPNTestMixin, AnchorHead):
class RPNHead(AnchorHead):
"""RPN head.

Args:
59 changes: 0 additions & 59 deletions mmdet/models/dense_heads/rpn_test_mixin.py

This file was deleted.

3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/sabl_retina_head.py
@@ -9,11 +9,12 @@
multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module()
class SABLRetinaHead(BaseDenseHead):
class SABLRetinaHead(BaseDenseHead, BBoxTestMixin):
"""Side-Aware Boundary Localization (SABL) for RetinaNet.

The anchor generation, assigning and sampling in SABLRetinaHead
48 changes: 48 additions & 0 deletions mmdet/models/dense_heads/yolact_head.py
@@ -572,6 +572,12 @@ def get_targets(self, segm_pred, gt_masks, gt_labels):
downsampled_masks[obj_idx])
return segm_targets

def simple_test(self, feats, img_metas, rescale=False, postprocess=True):
"""Test function without test-time augmentation."""
raise NotImplementedError(
'simple_test of YOLACTSegmHead is not implemented '
'because this head is only evaluated during training')


@HEADS.register_module()
class YOLACTProtonet(BaseModule):
@@ -924,6 +930,48 @@ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
x2 = torch.clamp(x2 + padding, max=img_size)
return x1, x2

def simple_test(self,
feats,
img_metas,
rescale=False,
det_bboxes=None,
det_labels=None,
det_coeffs=None):
"""Test function without test-time augmentation."""
assert det_bboxes is not None
assert det_labels is not None
assert det_coeffs is not None
num_imgs = len(img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
segm_results = [[[] for _ in range(self.num_classes)]
for _ in range(num_imgs)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
if rescale and not isinstance(scale_factors[0], float):
scale_factors = [
torch.from_numpy(scale_factor).to(det_bboxes[0].device)
for scale_factor in scale_factors
]
_bboxes = [
det_bboxes[i][:, :4] *
scale_factors[i] if rescale else det_bboxes[i][:, :4]
for i in range(len(det_bboxes))
]
mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas)
# apply mask post-processing to each image individually
segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
segm_results.append([[] for _ in range(self.num_classes)])
else:
segm_result = self.get_seg_masks(mask_preds[i],
det_labels[i],
img_metas[i], rescale)
segm_results.append(segm_result)
return segm_results


class InterpolateModule(BaseModule):
"""This is a module version of F.interpolate.