Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Simplify non batch nms #99

Merged
merged 1 commit into from
Jan 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 115 additions & 22 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import is_dynamic_batch


def select_nms_index(scores: torch.Tensor,
Expand Down Expand Up @@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

This function helps exporting to onnx with batch and multiclass NMS op.
It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 4).

Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.

Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
This function helps exporting to onnx with batch and multiclass NMS op. It
only supports class-agnostic detection results. That is, the scores is of
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
4).
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
Expand All @@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
return dets, labels


def _multiclass_nms_single(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

Single batch nms could be optimized.
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)

# pre topk
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.squeeze(0).topk(pre_top_k)
boxes = boxes[:, topk_inds, :]
scores = scores[:, topk_inds, :]

scores = scores.permute(0, 2, 1)
selected_indices = ONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)

cls_inds = selected_indices[:, 1]
box_inds = selected_indices[:, 2]

scores = scores[:, cls_inds, box_inds].unsqueeze(2)
boxes = boxes[:, box_inds, ...]
dets = torch.cat([boxes, scores], dim=2)
labels = cls_inds.unsqueeze(0)

# pad
dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

# topk or sort
is_use_topk = keep_top_k > 0 and \
(torch.onnx.is_in_onnx_export() or keep_top_k < dets.shape[1])
if is_use_topk:
_, topk_inds = dets[:, :, -1].topk(keep_top_k, dim=1)
else:
_, topk_inds = dets[:, :, -1].sort(dim=1, descending=True)
topk_inds = topk_inds.squeeze(0)
dets = dets[:, topk_inds, ...]
labels = labels[:, topk_inds, ...]

return dets, labels


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms')
def multiclass_nms__default(ctx,
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

This function helps exporting to onnx with batch and multiclass NMS op.
It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 4).

Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.

Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
deploy_cfg = ctx.cfg
batch_size = boxes.size(0)
if not is_dynamic_batch(deploy_cfg) and batch_size != 1:
return _multiclass_nms_single(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
return _multiclass_nms(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
backend='tensorrt')
Expand Down