refactor onnx export of two stage #5205

Merged: 25 commits, May 27, 2021
2 changes: 1 addition & 1 deletion mmdet/core/export/pytorch2onnx.py
@@ -146,7 +146,7 @@ def preprocess_example_input(input_config):
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'scale_factor': np.ones(4),
'flip': False,
'show_img': show_img,
}
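Why the example meta now uses np.ones(4): box rescaling in mmdet divides the (x1, y1, x2, y2) columns by a per-axis factor laid out as [w_scale, h_scale, w_scale, h_scale], so a scalar 1.0 no longer matches the expected shape. A minimal sketch of that convention, with made-up numbers (not part of this PR):

import numpy as np
import torch

scale_factor = np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32)  # hypothetical w/h scales
det_bboxes = torch.tensor([[100., 80., 300., 240., 0.9]])        # x1, y1, x2, y2, score
# rescale detections back to the original image resolution
det_bboxes[:, :4] /= det_bboxes.new_tensor(scale_factor)
print(det_bboxes)  # box coordinates doubled, score column untouched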
177 changes: 105 additions & 72 deletions mmdet/models/dense_heads/rpn_head.py
@@ -1,10 +1,8 @@
import copy
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict
from mmcv.ops import batched_nms

from ..builder import HEADS
@@ -119,8 +117,6 @@ def _get_bboxes(self,
mlvl_bbox_preds = []
mlvl_valid_anchors = []
batch_size = cls_scores[0].shape[0]
nms_pre_tensor = torch.tensor(
cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long)
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
@@ -140,33 +136,16 @@ def _get_bboxes(self,
batch_size, -1, 4)
anchors = mlvl_anchors[idx]
anchors = anchors.expand_as(rpn_bbox_pred)
# Get top-k prediction
from mmdet.core.export import get_k_for_topk
nms_pre = get_k_for_topk(nms_pre_tensor, rpn_bbox_pred.shape[1])
if nms_pre > 0:
_, topk_inds = scores.topk(nms_pre)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and rpn_bbox_pred.size(1) > nms_pre:
# sort is faster than topk
ranked_scores, rank_inds = scores.sort(descending=True)
topk_inds = rank_inds[:, :cfg.nms_pre]
scores = ranked_scores[:, :cfg.nms_pre]
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
if torch.onnx.is_in_onnx_export():
# Mind k<=3480 in TensorRT for TopK
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
batch_size, -1)
rpn_bbox_pred = rpn_bbox_pred.reshape(
-1, 4)[transformed_inds, :].reshape(batch_size, -1, 4)
anchors = anchors.reshape(-1,
4)[transformed_inds, :].reshape(
batch_size, -1, 4)
else:
# sort is faster than topk
ranked_scores, rank_inds = scores.sort(descending=True)
topk_inds = rank_inds[:, :cfg.nms_pre]
scores = ranked_scores[:, :cfg.nms_pre]
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds)
rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
anchors = anchors[batch_inds, topk_inds, :]
rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
anchors = anchors[batch_inds, topk_inds, :]

mlvl_scores.append(scores)
mlvl_bbox_preds.append(rpn_bbox_pred)
@@ -186,53 +165,11 @@ def _get_bboxes(self,
batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes)
batch_mlvl_ids = torch.cat(level_ids, dim=1)

# deprecate arguments warning
if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
warnings.warn(
'In rpn_proposal or test_cfg, '
'nms_thr has been moved to a dict named nms as '
'iou_threshold, max_num has been renamed as max_per_img, '
'name of original arguments and the way to specify '
'iou_threshold of NMS will be deprecated.')
if 'nms' not in cfg:
cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
if 'max_num' in cfg:
if 'max_per_img' in cfg:
assert cfg.max_num == cfg.max_per_img, f'You ' \
f'set max_num and ' \
f'max_per_img at the same time, but get {cfg.max_num} ' \
f'and {cfg.max_per_img} respectively' \
'Please delete max_num which will be deprecated.'
else:
cfg.max_per_img = cfg.max_num
if 'nms_thr' in cfg:
assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
f' iou_threshold in nms and ' \
f'nms_thr at the same time, but get' \
f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
f' respectively. Please delete the nms_thr ' \
f'which will be deprecated.'

# Replace batched_nms with ONNX::NonMaxSuppression in deployment
if torch.onnx.is_in_onnx_export():
from mmdet.core.export import add_dummy_nms_for_onnx
batch_mlvl_scores = batch_mlvl_scores.unsqueeze(2)
score_threshold = cfg.nms.get('score_thr', 0.0)
nms_pre = cfg.get('deploy_nms_pre', -1)
dets, _ = add_dummy_nms_for_onnx(batch_mlvl_proposals,
batch_mlvl_scores,
cfg.max_per_img,
cfg.nms.iou_threshold,
score_threshold, nms_pre,
cfg.max_per_img)
return dets

result_list = []
for (mlvl_proposals, mlvl_scores,
mlvl_ids) in zip(batch_mlvl_proposals, batch_mlvl_scores,
batch_mlvl_ids):
# Skip nonzero op while exporting to ONNX
if cfg.min_bbox_size >= 0 and (not torch.onnx.is_in_onnx_export()):
if cfg.min_bbox_size >= 0:
w = mlvl_proposals[:, 2] - mlvl_proposals[:, 0]
h = mlvl_proposals[:, 3] - mlvl_proposals[:, 1]
valid_ind = torch.nonzero(
@@ -248,3 +185,99 @@ def _get_bboxes(self,
cfg.nms)
result_list.append(dets[:cfg.max_per_img])
return result_list

# TODO: waiting for refactor the anchor_head and anchor_free head
def onnx_export(self, x, img_metas):
"""Test without augmentation.

Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
img_metas (list[dict]): Meta info of each image.

Returns:
Tensor: dets of shape [N, num_det, 5].
"""
cls_scores, bbox_preds = self(x)

assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)

device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)

cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]

assert len(
img_metas
) == 1, 'Only support one input image when exporting to ONNX'
img_shapes = img_metas[0]['img_shape_for_onnx']

cfg = copy.deepcopy(self.test_cfg)

mlvl_scores = []
mlvl_bbox_preds = []
mlvl_valid_anchors = []
batch_size = cls_scores[0].shape[0]
nms_pre_tensor = torch.tensor(
cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long)
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(batch_size, -1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(batch_size, -1, 2)
# We set FG labels to [0, num_class-1] and BG label to
# num_class in RPN head since mmdet v2.5, which is unified to
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = rpn_cls_score.softmax(-1)[..., 0]
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
batch_size, -1, 4)
anchors = mlvl_anchors[idx]
anchors = anchors.expand_as(rpn_bbox_pred)
# Get top-k prediction
from mmdet.core.export import get_k_for_topk
nms_pre = get_k_for_topk(nms_pre_tensor, rpn_bbox_pred.shape[1])
if nms_pre > 0:
_, topk_inds = scores.topk(nms_pre)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
# Mind k<=3480 in TensorRT for TopK
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
batch_size, -1)
rpn_bbox_pred = rpn_bbox_pred.reshape(
-1, 4)[transformed_inds, :].reshape(batch_size, -1, 4)
anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
mlvl_scores.append(scores)
mlvl_bbox_preds.append(rpn_bbox_pred)
mlvl_valid_anchors.append(anchors)

batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_rpn_bbox_pred = torch.cat(mlvl_bbox_preds, dim=1)
batch_mlvl_proposals = self.bbox_coder.decode(
batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes)

# Use ONNX::NonMaxSuppression in deployment
from mmdet.core.export import add_dummy_nms_for_onnx
batch_mlvl_scores = batch_mlvl_scores.unsqueeze(2)
score_threshold = cfg.nms.get('score_thr', 0.0)
nms_pre = cfg.get('deploy_nms_pre', -1)
dets, _ = add_dummy_nms_for_onnx(batch_mlvl_proposals,
batch_mlvl_scores, cfg.max_per_img,
cfg.nms.iou_threshold,
score_threshold, nms_pre,
cfg.max_per_img)
return dets
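The transformed_inds gather above works around the linked onnx2tensorrt problem: indexing with a (batch_inds, topk_inds) pair exports to ops that TensorRT handles poorly, so the batch dimension is folded into a single flat index over a reshaped tensor instead. A small self-contained sketch (not from this PR) showing the two gathers select the same boxes:

import torch

batch_size, num_anchors, k = 2, 100, 5
scores = torch.rand(batch_size, num_anchors)
boxes = torch.rand(batch_size, num_anchors, 4)

_, topk_inds = scores.topk(k)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)

# plain advanced indexing (the non-export path)
gathered = boxes[batch_inds, topk_inds, :]

# flattened indexing (the ONNX/TensorRT-friendly path): fold batch into one axis
transformed_inds = num_anchors * batch_inds + topk_inds
gathered_flat = boxes.reshape(-1, 4)[transformed_inds, :].reshape(batch_size, -1, 4)

assert torch.equal(gathered, gathered_flat)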
8 changes: 8 additions & 0 deletions mmdet/models/detectors/base.py
@@ -163,6 +163,10 @@ def forward(self, img, img_metas, return_loss=True, **kwargs):
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if torch.onnx.is_in_onnx_export():
assert len(img_metas) == 1
return self.onnx_export(img[0], img_metas[0])

if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
@@ -339,3 +343,7 @@ def show_result(self,

if not (show or out_file):
return img

def onnx_export(self, img, img_metas):
raise NotImplementedError(f'{self.__class__.__name__} does '
f'not support ONNX EXPORT')
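The dispatch added to forward() relies on torch.onnx.is_in_onnx_export(), which is only True while torch.onnx.export() traces the model, so ordinary inference is unaffected. A toy sketch of the pattern (not mmdet code):

import torch


class Toy(torch.nn.Module):
    """Mimics the branch added to BaseDetector.forward above."""

    def forward(self, x):
        if torch.onnx.is_in_onnx_export():
            return self.onnx_export(x)
        return x + 1

    def onnx_export(self, x):
        return x * 2


toy = Toy().eval()
print(toy(torch.ones(1)))                          # eager path: tensor([2.])
torch.onnx.export(toy, torch.ones(1), 'toy.onnx')  # traced graph records x * 2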
32 changes: 24 additions & 8 deletions mmdet/models/detectors/single_stage.py
@@ -94,16 +94,9 @@ def simple_test(self, img, img_metas, rescale=False):
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
# get origin input shape to support onnx dynamic shape
if torch.onnx.is_in_onnx_export():
# get shape as tensor
img_shape = torch._shape_as_tensor(img)[2:]
img_metas[0]['img_shape_for_onnx'] = img_shape

bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
if torch.onnx.is_in_onnx_export():
return bbox_list

bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
@@ -135,3 +128,26 @@ def aug_test(self, imgs, img_metas, rescale=False):

feats = self.extract_feats(imgs)
return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]

def onnx_export(self, img, img_metas):
"""Test function without test time augmentation.

Args:
img (torch.Tensor): input images.
img_metas (list[dict]): List of image information.

Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
# get the original input shape as a tensor to support onnx dynamic shape
img_shape = torch._shape_as_tensor(img)[2:]
img_metas[0]['img_shape_for_onnx'] = img_shape
# TODO: move all onnx related code in bbox_head to onnx_export function
det_bboxes, det_labels = self.bbox_head.get_bboxes(*outs, img_metas)

return det_bboxes, det_labels
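On the 'img_shape_for_onnx' meta set above: torch._shape_as_tensor is used instead of img.shape because plain Python ints are baked into the traced graph as constants, while a shape tensor stays symbolic, which is what allows dynamic input sizes after export. A quick illustration (input size is an arbitrary assumption):

import torch

img = torch.randn(1, 3, 800, 1216)
static_hw = img.shape[2:]                     # Python ints: constant-folded when traced
dynamic_hw = torch._shape_as_tensor(img)[2:]  # Tensor: exported as a Shape/Slice subgraph
print(static_hw, dynamic_hw)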
16 changes: 9 additions & 7 deletions mmdet/models/detectors/two_stage.py
@@ -165,15 +165,9 @@ async def async_simple_test(self,

def simple_test(self, img, img_metas, proposals=None, rescale=False):
"""Test without augmentation."""
assert self.with_bbox, 'Bbox head must be implemented.'

assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)

# get origin input shape to onnx dynamic input shape
if torch.onnx.is_in_onnx_export():
img_shape = torch._shape_as_tensor(img)[2:]
img_metas[0]['img_shape_for_onnx'] = img_shape

if proposals is None:
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
else:
@@ -192,3 +186,11 @@ def aug_test(self, imgs, img_metas, rescale=False):
proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
return self.roi_head.aug_test(
x, proposal_list, img_metas, rescale=rescale)

def onnx_export(self, img, img_metas):

img_shape = torch._shape_as_tensor(img)[2:]
img_metas[0]['img_shape_for_onnx'] = img_shape
x = self.extract_feat(img)
proposals = self.rpn_head.onnx_export(x, img_metas)
return self.roi_head.onnx_export(x, proposals, img_metas)
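Once a two-stage detector exports through this path, the resulting graph can be sanity-checked with onnxruntime. A hedged sketch (the file name, input size, and the (dets, labels) output pair are assumptions, not guaranteed by this diff):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('faster_rcnn.onnx')            # placeholder file name
img = np.random.randn(1, 3, 800, 1216).astype(np.float32)  # placeholder input size
dets, labels = sess.run(None, {sess.get_inputs()[0].name: img})
print(dets.shape, labels.shape)  # e.g. (1, num_det, 5) and (1, num_det)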