Skip to content

Commit

Permalink
leave a TODO on write by only one thread
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent d75ee0a commit 0aa375d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,20 +530,21 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
def nms_inner_loop(ib, j):
# the box j is valid, invalidate other boxes that overlap with j above iou_threshold

# # When return_indices is False, no need to populate box_indices
# if return_indices:
# # Only one thread needs to do this write
# with ib.if_scope(tx == 0):
# orig_idx = sorted_index[i * num_anchors + j]
# box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]

orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
# TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen
# # Only one thread needs to do this write
# with ib.if_scope(tx == 0):
# orig_idx = sorted_index[i * num_anchors + j]
# box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]

num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
num_iter_per_thread = ceil_div(num_anchors - (j + 1), nthread_tx)
num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx)

with ib.for_range(0, num_iter_per_thread) as _k:
k = j + 1 + _k * nthread_tx + tx
Expand All @@ -552,7 +553,7 @@ def nms_inner_loop(ib, j):
with ib.if_scope(
tvm.tir.all(
k < num_anchors,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
Expand Down Expand Up @@ -596,8 +597,10 @@ def nms_inner_loop(ib, j):
with ib.else_scope():
nms_inner_loop(ib, j)

# with ib.if_scope(tx == 0):
num_valid_boxes[i] = num_valid_boxes_local[0]
# TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen
# with ib.if_scope(tx == 0):
# num_valid_boxes[i] = num_valid_boxes_local[0]

with ib.else_scope():
num_valid_boxes[i] = 0
Expand Down

0 comments on commit 0aa375d

Please sign in to comment.