Skip to content

Commit

Permalink
make NMS inner loop parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent 2dec2dd commit dd1e230
Showing 1 changed file with 49 additions and 33 deletions.
82 changes: 49 additions & 33 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,51 +512,62 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

with ib.new_scope():
nthread_by = batch_size
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)

by = te.thread_axis("blockIdx.y")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(tx, "thread_extent", nthread_tx)

i = by
k = bx * nthread_tx + tx
base_idx = i * num_anchors * box_data_length
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0

def nms_inner_loop(ib, j):
offset_j = j * box_data_length
# box j is valid, invalidate other boxes that overlap with j above iou_threshold

with ib.for_range(0, j) as k:
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_j + score_index] > -1.0, # if already surpressed
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0

# Has the box j survived IOU tests?
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
# When return_indices is False, no need to populate box_indices
if return_indices:
# When return_indices is False, no need to populate box_indices
if return_indices:
# Only one thread needs to this write
with ib.if_scope(k == 0):
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
num_valid_boxes_local[0] += 1

num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
j < k,
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_k + id_index] = -1.0

ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

if isinstance(max_output_size, int):
max_output_size = tvm.tir.const(max_output_size)
Expand All @@ -565,7 +576,12 @@ def nms_inner_loop(ib, j):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
tvm.tir.all(
out[base_idx + (j * box_data_length) + score_index] > -1.0,
tvm.tir.any(
id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0
),
)
):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we already reach max_output_size boxes
Expand Down

0 comments on commit dd1e230

Please sign in to comment.