[Refactor] Add simple_test to dense heads
shinya7y committed Apr 25, 2021
1 parent 8a432d2 commit 7578b6c
Showing 11 changed files with 210 additions and 88 deletions.
19 changes: 19 additions & 0 deletions mmdet/models/dense_heads/base_dense_head.py
@@ -57,3 +57,22 @@ def forward_train(self,
        else:
            proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list

    def simple_test(self, feats, img_metas, rescale=False, postprocess=True):
        """Test function without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            postprocess (bool, optional): Whether to perform post-processing
                by bbox2result. Defaults to True.
        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        return self.simple_test_bboxes(
            feats, img_metas, rescale=rescale, postprocess=postprocess)
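Editor's note: for context, a minimal usage sketch of the new head-level test API added above. It is not part of this commit; the config path, input sizes, and img_metas keys are assumptions, and checkpoint=None gives random weights, so the detections themselves are meaningless. The point is the call signature and the structure of the return value.

```python
# Sketch only: assumes an MMDetection checkout containing this change, run from
# the repository root so the (assumed) config path resolves.
import torch
from mmdet.apis import init_detector

model = init_detector(
    'configs/retinanet/retinanet_r50_fpn_1x_coco.py',  # assumed config path
    checkpoint=None,  # random weights; outputs are meaningless
    device='cpu')
model.eval()

img = torch.randn(1, 3, 800, 800)
img_metas = [dict(
    img_shape=(800, 800, 3),
    ori_shape=(800, 800, 3),
    pad_shape=(800, 800, 3),
    scale_factor=1.0)]

with torch.no_grad():
    feats = model.extract_feat(img)
    # New in this commit: the dense head itself exposes simple_test.
    results = model.bbox_head.simple_test(feats, img_metas, rescale=False)

# results: one entry per image; each entry is a per-class list of (N, 5)
# arrays holding [x1, y1, x2, y2, score].
assert len(results) == 1
assert len(results[0]) == model.bbox_head.num_classes
```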
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/corner_head.py
@@ -11,6 +11,7 @@
from ..builder import HEADS, build_loss
from ..utils import gaussian_radius, gen_gaussian_target
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


class BiCornerPool(nn.Module):
@@ -74,7 +75,7 @@ def forward(self, x):


@HEADS.register_module()
class CornerHead(BaseDenseHead):
class CornerHead(BaseDenseHead, BBoxTestMixin):
"""Head of CornerNet: Detecting Objects as Paired Keypoints.
Code is modified from the `official github repo
32 changes: 32 additions & 0 deletions mmdet/models/dense_heads/dense_test_mixins.py
@@ -8,6 +8,38 @@
class BBoxTestMixin(object):
    """Mixin class for test time augmentation of bboxes."""

    def simple_test_bboxes(self,
                           feats,
                           img_metas,
                           rescale=False,
                           postprocess=True):
        """Test det bboxes without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            postprocess (bool, optional): Whether to perform post-processing
                by bbox2result. Defaults to True.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        outs = self.forward(feats)
        bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
        if postprocess:
            bbox_results = [
                bbox2result(det_bboxes, det_labels, self.num_classes)
                for det_bboxes, det_labels in bbox_list
            ]
            return bbox_results
        else:
            return bbox_list

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Merge augmented detection bboxes and scores.
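Editor's note: to make the postprocess flag above concrete, here is a small sketch of what bbox2result (an existing mmdet.core helper, not added by this commit) does with the raw (det_bboxes, det_labels) pairs that are returned when postprocess=False. The tensors are made-up placeholders.

```python
import torch
from mmdet.core import bbox2result

num_classes = 80
# Two placeholder detections of classes 0 and 2; rows are [x1, y1, x2, y2, score].
det_bboxes = torch.tensor([[10., 10., 50., 60., 0.9],
                           [20., 30., 70., 90., 0.8]])
det_labels = torch.tensor([0, 2])

per_class = bbox2result(det_bboxes, det_labels, num_classes)
assert len(per_class) == num_classes     # one (N_c, 5) ndarray per class
assert per_class[0].shape == (1, 5)      # one detection for class 0
assert per_class[1].shape == (0, 5)      # none for class 1
```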
4 changes: 4 additions & 0 deletions mmdet/models/dense_heads/embedding_rpn_head.py
@@ -98,3 +98,7 @@ def forward_train(self, img, img_metas):
    def simple_test_rpn(self, img, img_metas):
        """Forward function in testing stage."""
        return self._decode_init_proposals(img, img_metas)

    def simple_test(self, img, img_metas):
        """Forward function in testing stage."""
        raise NotImplementedError
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/sabl_retina_head.py
@@ -9,11 +9,12 @@
                        multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module()
class SABLRetinaHead(BaseDenseHead):
class SABLRetinaHead(BaseDenseHead, BBoxTestMixin):
"""Side-Aware Boundary Localization (SABL) for RetinaNet.
The anchor generation, assigning and sampling in SABLRetinaHead
39 changes: 38 additions & 1 deletion mmdet/models/dense_heads/transformer_head.py
@@ -4,7 +4,7 @@
from mmcv.cnn import Conv2d, Linear, build_activation_layer
from mmcv.runner import force_fp32

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
from mmdet.core import (bbox2result, bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
                        build_assigner, build_sampler, multi_apply,
                        reduce_mean)
from mmdet.models.utils import (FFN, build_positional_encoding,
@@ -652,3 +652,40 @@ def _get_bboxes_single(self,
            det_bboxes /= det_bboxes.new_tensor(scale_factor)
        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)
        return det_bboxes, det_labels

    def simple_test_bboxes(self,
                           feats,
                           img_metas,
                           rescale=False,
                           postprocess=True):
        """Test det bboxes without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            postprocess (bool, optional): Whether to perform post-processing
                by bbox2result. Defaults to True.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        batch_size = len(img_metas)
        assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
            f'mode is supported. Found batch_size {batch_size}.'

        # forward of this head requires img_metas
        outs = self.forward(feats, img_metas)
        bbox_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
        if postprocess:
            bbox_results = [
                bbox2result(det_bboxes, det_labels, self.num_classes)
                for det_bboxes, det_labels in bbox_list
            ]
            return bbox_results
        else:
            return bbox_list
47 changes: 47 additions & 0 deletions mmdet/models/dense_heads/yolact_head.py
@@ -572,6 +572,12 @@ def get_targets(self, segm_pred, gt_masks, gt_labels):
                    downsampled_masks[obj_idx])
            return segm_targets

    def simple_test(self, feats, img_metas, rescale=False, postprocess=True):
        """Test function without test-time augmentation."""
        raise NotImplementedError(
            'simple_test of YOLACTSegmHead is not implemented '
            'because this head is only evaluated during training')


@HEADS.register_module()
class YOLACTProtonet(nn.Module):
@@ -925,6 +931,47 @@ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
        x2 = torch.clamp(x2 + padding, max=img_size)
        return x1, x2

    def simple_test(self,
                    feats,
                    img_metas,
                    rescale=False,
                    det_bboxes=None,
                    det_labels=None,
                    det_coeffs=None):
        assert det_bboxes
        assert det_labels
        assert det_coeffs
        num_imgs = len(img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            segm_results = [[[] for _ in range(self.num_classes)]
                            for _ in range(num_imgs)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            if rescale and not isinstance(scale_factors[0], float):
                scale_factors = [
                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                det_bboxes[i][:, :4] *
                scale_factors[i] if rescale else det_bboxes[i][:, :4]
                for i in range(len(det_bboxes))
            ]
            mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas)
            # apply mask post-processing to each image individually
            segm_results = []
            for i in range(num_imgs):
                if det_bboxes[i].shape[0] == 0:
                    segm_results.append([[] for _ in range(self.num_classes)])
                else:
                    segm_result = self.get_seg_masks(mask_preds[i],
                                                     det_labels[i],
                                                     img_metas[i], rescale)
                    segm_results.append(segm_result)
        return segm_results


class InterpolateModule(nn.Module):
    """This is a module version of F.interpolate.
29 changes: 0 additions & 29 deletions mmdet/models/detectors/detr.py
@@ -1,4 +1,3 @@
from mmdet.core import bbox2result
from ..builder import DETECTORS
from .single_stage import SingleStageDetector

@@ -16,31 +15,3 @@ def __init__(self,
                 pretrained=None):
        super(DETR, self).__init__(backbone, None, bbox_head, train_cfg,
                                   test_cfg, pretrained)

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        batch_size = len(img_metas)
        assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
            f'mode is supported. Found batch_size {batch_size}.'
        x = self.extract_feat(img)
        outs = self.bbox_head(x, img_metas)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)

        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results
29 changes: 11 additions & 18 deletions mmdet/models/detectors/single_stage.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn

from mmdet.core import bbox2result
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector

@@ -95,38 +94,32 @@ def forward_train(self,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.
    def simple_test(self, img, img_metas, rescale=False, postprocess=True):
        """Test function without test-time augmentation.
        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img (torch.Tensor): Images with shape (N, C, H, W).
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            postprocess (bool, optional): Whether to perform post-processing
                by bbox2result. Defaults to True.
        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        # get origin input shape to support onnx dynamic shape
        if torch.onnx.is_in_onnx_export():
            # get shape as tensor
            # get origin input shape as tensor to support onnx dynamic shape
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        # skip post-processing when exporting to ONNX
        if torch.onnx.is_in_onnx_export():
            return bbox_list
            # skip post-processing when exporting to ONNX
            postprocess = False

        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results
        feat = self.extract_feat(img)
        return self.bbox_head.simple_test(
            feat, img_metas, rescale=rescale, postprocess=postprocess)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.
50 changes: 12 additions & 38 deletions mmdet/models/detectors/yolact.py
@@ -96,49 +96,23 @@ def forward_train(self,
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation."""
        x = self.extract_feat(img)

        cls_score, bbox_pred, coeff_pred = self.bbox_head(x)

        bbox_inputs = (cls_score, bbox_pred,
                       coeff_pred) + (img_metas, self.test_cfg, rescale)
        det_bboxes, det_labels, det_coeffs = self.bbox_head.get_bboxes(
            *bbox_inputs)
        """Test function without test-time augmentation."""
        feat = self.extract_feat(img)
        det_bboxes, det_labels, det_coeffs = self.bbox_head.simple_test(
            feat, img_metas, rescale=rescale, postprocess=False)
        bbox_results = [
            bbox2result(det_bbox, det_label, self.bbox_head.num_classes)
            for det_bbox, det_label in zip(det_bboxes, det_labels)
        ]

        num_imgs = len(img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            segm_results = [[[] for _ in range(self.mask_head.num_classes)]
                            for _ in range(num_imgs)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            if rescale and not isinstance(scale_factors[0], float):
                scale_factors = [
                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                det_bboxes[i][:, :4] *
                scale_factors[i] if rescale else det_bboxes[i][:, :4]
                for i in range(len(det_bboxes))
            ]
            mask_preds = self.mask_head(x[0], det_coeffs, _bboxes, img_metas)
            # apply mask post-processing to each image individually
            segm_results = []
            for i in range(num_imgs):
                if det_bboxes[i].shape[0] == 0:
                    segm_results.append(
                        [[] for _ in range(self.mask_head.num_classes)])
                else:
                    segm_result = self.mask_head.get_seg_masks(
                        mask_preds[i], det_labels[i], img_metas[i], rescale)
                    segm_results.append(segm_result)
        segm_results = self.mask_head.simple_test(
            feat,
            img_metas,
            rescale=rescale,
            det_bboxes=det_bboxes,
            det_labels=det_labels,
            det_coeffs=det_coeffs)

        return list(zip(bbox_results, segm_results))

    def aug_test(self, imgs, img_metas, rescale=False):
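Editor's note: as a closing illustration, the structure YOLACT.simple_test returns after this refactor is list(zip(bbox_results, segm_results)), i.e. one (per-class boxes, per-class masks) pair per image. The sketch below uses empty placeholder results rather than a real model; the class count is an assumption.

```python
import numpy as np

num_classes = 80  # assumed COCO-style class count
# Per-image, per-class box arrays, as produced by bbox2result (empty here).
bbox_results = [[np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]]
# Per-image, per-class mask lists, as produced by the mask head (empty here).
segm_results = [[[] for _ in range(num_classes)]]

results = list(zip(bbox_results, segm_results))  # shape of simple_test's output
boxes_per_class, masks_per_class = results[0]    # first (and only) image
assert len(boxes_per_class) == len(masks_per_class) == num_classes
```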