Skip to content

Commit

Permalink
[TOPI] Parallelize GPU NMS inner loop (#7172)
Browse files Browse the repository at this point in the history
* make NMS inner loop parallel

* use one block two avoid global sync issue

* temp disable write by only thread 0

* leave a TODO on write by only one thread

* add some comments, remove check the check on negative class id

* minor improvement when topk is available

* fix write by a single thread
  • Loading branch information
masahi authored Dec 30, 2020
1 parent f2ab977 commit 66e123f
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,26 +512,44 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

with ib.new_scope():
nthread_by = batch_size
nthread_tx = max_threads

by = te.thread_axis("blockIdx.y")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(tx, "thread_extent", nthread_tx)

i = by

base_idx = i * num_anchors * box_data_length
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0
nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])

def nms_inner_loop(ib, j):
# The box j is valid, invalidate other boxes that overlap with j above iou_threshold

# When return_indices is False, no need to populate box_indices
if return_indices:
with ib.if_scope(tx + 0 == 0):
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]

num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

with ib.for_range(0, j) as k:
with ib.for_range(0, num_iter_per_thread) as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_j + score_index] > -1.0, # if already surpressed
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
k < nkeep,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
tvm.tir.any(
force_suppress > 0,
id_index < 0,
Expand All @@ -546,35 +564,31 @@ def nms_inner_loop(ib, j):
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
# invalidate the box k
out[base_idx + offset_k + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
out[base_idx + offset_k + id_index] = -1.0

# Has the box j survived IOU tests?
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
num_valid_boxes_local[0] += 1
# Make sure to do the next loop in a lock step
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

if isinstance(max_output_size, int):
max_output_size = tvm.tir.const(max_output_size)

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
):
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we already reach max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
nms_inner_loop(ib, j)

num_valid_boxes[i] = num_valid_boxes_local[0]
with ib.if_scope(tx + 0 == 0):
num_valid_boxes[i] = num_valid_boxes_local[0]

with ib.else_scope():
num_valid_boxes[i] = 0
Expand Down

0 comments on commit 66e123f

Please sign in to comment.