Skip to content

Commit

Permalink
all class nms tf mode first cut
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 5f349f7 commit cde4a1f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 21 deletions.
129 changes: 113 additions & 16 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.ir import register_intrin_lowering
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .sort import argsort, argsort_thrust, topk
from .scan import exclusive_scan
from ..utils import ceil_div
from ..math import cast
from ..transform import reshape
from .. import reduction
from ..broadcast import minimum
from ..transform import reshape, strided_slice, gather_nd
from ..vision.nms_util import (
calculate_overlap,
binary_search,
Expand Down Expand Up @@ -988,8 +990,83 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro
return ib.get()


def _collect_selected_indices_tf_ir(
num_class,
selected_indices,
selected_scores,
num_detections,
row_offsets,
collected_indices,
collected_scores,
):
batch_size, num_class = row_offsets.shape
num_boxes = selected_indices.shape[1]

ib = tvm.tir.ir_builder.create()

selected_indices = ib.buffer_ptr(selected_indices)
num_detections = ib.buffer_ptr(num_detections)
row_offsets = ib.buffer_ptr(row_offsets)
collected_indices = ib.buffer_ptr(collected_indices)
collected_scores = ib.buffer_ptr(collected_scores)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(num_boxes, nthread_tx)
nthread_by = batch_size * num_class
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)

with ib.new_scope():
idx = bx * nthread_tx + tx
idy = cast(by, "int64")
batch_id = idy // num_class
class_id = idy % num_class
with ib.if_scope(idx < num_detections[batch_id, class_id]):
offset = row_offsets[batch_id, class_id]
collected_indices[batch_id, offset + idx, 0] = class_id
collected_indices[batch_id, offset + idx, 1] = cast(selected_indices[idy, idx], "int64")
collected_scores[batch_id, offset + idx] = selected_scores[idy, idx]

return ib.get()


def collect_selected_indices_tf(selected_indices, selected_scores, num_detections, row_offsets):
batch_size, num_class = row_offsets.shape
num_boxes = selected_indices.shape[1]

selected_indices_buf = tvm.tir.decl_buffer(
selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8
)
selected_scores_buf = tvm.tir.decl_buffer(
selected_scores.shape, selected_indices.dtype, "selected_scores_buf", data_alignment=8
)
num_detections_buf = tvm.tir.decl_buffer(
num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8
)
row_offsets_buf = tvm.tir.decl_buffer(
row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8
)

return te.extern(
[(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)],
[selected_indices, selected_scores, num_detections, row_offsets],
lambda ins, outs: _collect_selected_indices_tf_ir(
num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1]
),
dtype=["int64"],
in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf],
name="collect_indices",
tag="collect_indices",
)


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_detection_per_batch=-1, output_format="onnx"
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="onnx"
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand Down Expand Up @@ -1027,24 +1104,21 @@ def all_class_non_max_suppression(
"""
batch, num_class, num_boxes = scores.shape

if max_detection_per_batch == -1:
max_detection_per_batch = num_class * num_boxes

scores = reshape(scores, (batch * num_class, num_boxes))
sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)

selected_indices, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
)

if output_format == "onnx":
selected_indices, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
)

row_offsets, num_total_detections = exclusive_scan(
num_detections, return_reduction=True, output_dtype="int64"
)
Expand All @@ -1053,8 +1127,31 @@ def all_class_non_max_suppression(
)
return [selected_indices, num_total_detections]

max_detection_per_batch = 100

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
True,
)

# tf mode, return (batch_size, max_total_size, 2)
num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1
)
selected_indices, selected_scores = collect_selected_indices_tf(
selected_indices, selected_scores, num_detections_per_batch, row_offsets
)
selected_scores = strided_slice(
selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)]
)
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_detections]
4 changes: 2 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def gather(data, axis, indices):
return cpp.gather(data, axis, indices)


def gather_nd(a, indices):
def gather_nd(a, indices, batch_dims=0):
"""Gather elements from a n-dimension array..
Parameters
Expand All @@ -498,7 +498,7 @@ def gather_nd(a, indices):
-------
ret : tvm.te.Tensor
"""
return cpp.gather_nd(a, indices)
return cpp.gather_nd(a, indices, batch_dims)


def matmul(a, b, transp_a=False, transp_b=False):
Expand Down
43 changes: 41 additions & 2 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _all_class_nms_ir(
iou_threshold,
max_output_size_per_class,
box_indices,
selected_scores,
num_valid_boxes,
nms_loop,
):
Expand All @@ -203,6 +204,9 @@ def _all_class_nms_ir(
box_indices = ib.buffer_ptr(box_indices)
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)

if selected_scores is not None:
selected_scores = ib.buffer_ptr(selected_scores)

if isinstance(iou_threshold, float):
iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)

Expand All @@ -224,6 +228,9 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
with ib.if_scope(tid + 0 == 0):
box_indices[i, num_current_valid_box] = sorted_indices[i, j]

if selected_scores is not None:
selected_scores[i, num_current_valid_box] = sorted_scores[i, j]

def on_new_invalidated_box(*_):
pass

Expand Down Expand Up @@ -254,6 +261,7 @@ def run_all_class_nms(
max_output_size_per_class,
iou_threshold,
nms_loop,
return_scores=False
):
"""The core all class NMS routine
Expand Down Expand Up @@ -306,8 +314,38 @@ def run_all_class_nms(
valid_count.shape, "int32", "valid_count_buf", data_alignment=4
)

if return_scores:
return te.extern(
[(batch_class, num_boxes), (1, batch_class)],
[boxes, sorted_scores, sorted_indices, valid_count],
lambda ins, outs: _all_class_nms_ir(
ins[0], # boxes
ins[1], # sorted_scores
ins[2], # sorted_indices
ins[3], # valid_count
batch_class,
num_class,
num_boxes,
iou_threshold,
max_output_size_per_class,
outs[0], # box_indices
None, # scores
outs[1], # num_selected_boxes
nms_loop,
),
dtype=["int32", "int32"],
in_buffers=[
boxes_buf,
sorted_scores_buf,
sorted_indices_buf,
valid_count_buf,
],
name="all_class_nms",
tag="all_class_nms",
)

return te.extern(
[(batch_class, num_boxes), (1, batch_class)],
[(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)],
[boxes, sorted_scores, sorted_indices, valid_count],
lambda ins, outs: _all_class_nms_ir(
ins[0], # boxes
Expand All @@ -320,7 +358,8 @@ def run_all_class_nms(
iou_threshold,
max_output_size_per_class,
outs[0], # box_indices
outs[1], # num_selected_boxes
outs[1], # selected scores
outs[2], # num_selected_boxes
nms_loop,
),
dtype=["int32", "int32"],
Expand Down
2 changes: 1 addition & 1 deletion src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = gather_nd(args[0], args[1]);
*rv = gather_nd(args[0], args[1], args[2]);
});

TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down

0 comments on commit cde4a1f

Please sign in to comment.