simplify non batch nms (open-mmlab#99)
q.yao authored Jan 26, 2022
1 parent a543d41 commit f2d0b15
Showing 1 changed file with 115 additions and 22 deletions.
137 changes: 115 additions & 22 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
@@ -5,6 +5,7 @@
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import is_dynamic_batch


def select_nms_index(scores: torch.Tensor,
@@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
This function helps exporting to onnx with batch and multiclass NMS op.
It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 4).
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
This function helps exporting to onnx with batch and multiclass NMS op. It
only supports class-agnostic detection results. That is, the scores is of
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
4).
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
@@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
return dets, labels
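
# Editorial illustration (plain torch, made-up sizes), not part of this file:
# a minimal sketch of the shape convention assumed above. Class-agnostic boxes
# are [N, num_boxes, 4] and scores are [N, num_boxes, num_classes], while
# onnx::NonMaxSuppression consumes scores as [N, num_classes, num_boxes],
# hence the permute before the op. The op returns selected indices of shape
# [num_selected, 3], laid out as (batch_index, class_index, box_index).
import torch

N, num_boxes, num_classes = 2, 8, 3
boxes = torch.rand(N, num_boxes, 4)
scores = torch.rand(N, num_boxes, num_classes)
nms_scores = scores.permute(0, 2, 1)
assert nms_scores.shape == (N, num_classes, num_boxes)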


def _multiclass_nms_single(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
Single batch nms could be optimized.
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)

# pre topk
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.squeeze(0).topk(pre_top_k)
boxes = boxes[:, topk_inds, :]
scores = scores[:, topk_inds, :]

scores = scores.permute(0, 2, 1)
selected_indices = ONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)

cls_inds = selected_indices[:, 1]
box_inds = selected_indices[:, 2]

scores = scores[:, cls_inds, box_inds].unsqueeze(2)
boxes = boxes[:, box_inds, ...]
dets = torch.cat([boxes, scores], dim=2)
labels = cls_inds.unsqueeze(0)

# pad
dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

# topk or sort
is_use_topk = keep_top_k > 0 and \
(torch.onnx.is_in_onnx_export() or keep_top_k < dets.shape[1])
if is_use_topk:
_, topk_inds = dets[:, :, -1].topk(keep_top_k, dim=1)
else:
_, topk_inds = dets[:, :, -1].sort(dim=1, descending=True)
topk_inds = topk_inds.squeeze(0)
dets = dets[:, topk_inds, ...]
labels = labels[:, topk_inds, ...]

return dets, labels
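
# Editorial illustration (plain torch, hypothetical index values), not part of
# this file: a minimal sketch of the single-batch gather above. With a batch of
# one, columns 1 and 2 of the selected indices address scores and boxes
# directly, so the per-image bookkeeping done by select_nms_index for the
# batched path is not needed.
import torch

boxes = torch.rand(1, 8, 4)                  # [1, num_boxes, 4]
scores = torch.rand(1, 3, 8)                 # permuted to [1, num_classes, num_boxes]
selected_indices = torch.tensor([[0, 1, 4],  # (batch_index, class_index, box_index)
                                 [0, 0, 2],
                                 [0, 2, 7]])

cls_inds = selected_indices[:, 1]
box_inds = selected_indices[:, 2]
dets = torch.cat(
    [boxes[:, box_inds, ...], scores[:, cls_inds, box_inds].unsqueeze(2)],
    dim=2)
labels = cls_inds.unsqueeze(0)
assert dets.shape == (1, 3, 5) and labels.shape == (1, 3)

# Mirror of the "pad" and "topk or sort" steps: pad one all-zero detection so
# keep_top_k can always be satisfied, then keep the highest-scoring rows.
dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)
keep_top_k = 2
_, topk_inds = dets[:, :, -1].topk(keep_top_k, dim=1)
dets = dets[:, topk_inds.squeeze(0), ...]
labels = labels[:, topk_inds.squeeze(0)]
assert dets.shape == (1, keep_top_k, 5) and labels.shape == (1, keep_top_k)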


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms')
def multiclass_nms__default(ctx,
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
This function helps exporting to onnx with batch and multiclass NMS op.
It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 4).
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
deploy_cfg = ctx.cfg
batch_size = boxes.size(0)
    if not is_dynamic_batch(deploy_cfg) and batch_size == 1:
return _multiclass_nms_single(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
return _multiclass_nms(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
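
# Editorial, hypothetical sketch of the dispatch above (names and the config
# fragment are made up, not mmdeploy API): a deployment config that marks
# axis 0 of the network input as dynamic in its ONNX dynamic_axes keeps the
# batched NMS path; a static, single-image config can take the simplified
# single-batch path.
def _wants_single_path(dynamic_axes: dict, batch_size: int) -> bool:
    batch_is_dynamic = any(0 in axes for axes in dynamic_axes.values())
    return not batch_is_dynamic and batch_size == 1

assert not _wants_single_path({'input': {0: 'batch', 2: 'h', 3: 'w'}}, 1)
assert _wants_single_path({'input': {2: 'h', 3: 'w'}}, 1)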


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
backend='tensorrt')
