Skip to content

Commit

Permalink
Add some comments, remove the check on negative class id
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent 0aa375d commit 9b42008
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
num_valid_boxes_local[0] = 0

def nms_inner_loop(ib, j):
# the box j is valid, invalidate other boxes that overlap with j above iou_threshold
# The box j is valid, invalidate other boxes that overlap with j above iou_threshold

# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]

# TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen
# TODO(masahi): Want to do this instead of above, but the following is eliminated
# during codegen
# # Only one thread needs to do this write
# with ib.if_scope(tx == 0):
# orig_idx = sorted_index[i * num_anchors + j]
Expand All @@ -554,7 +555,6 @@ def nms_inner_loop(ib, j):
tvm.tir.all(
k < num_anchors,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
Expand All @@ -574,6 +574,7 @@ def nms_inner_loop(ib, j):
with ib.if_scope(id_index >= 0):
out[base_idx + offset_k + id_index] = -1.0

# Make sure all threads proceed to the next loop iteration in lock step
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

if isinstance(max_output_size, int):
Expand All @@ -582,14 +583,8 @@ def nms_inner_loop(ib, j):
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(
tvm.tir.all(
out[base_idx + (j * box_data_length) + score_index] > -1.0,
tvm.tir.any(
id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0
),
)
):
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iterations if we have already reached max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
Expand All @@ -598,7 +593,8 @@ def nms_inner_loop(ib, j):
nms_inner_loop(ib, j)

num_valid_boxes[i] = num_valid_boxes_local[0]
# TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen
# TODO(masahi): Want to do this instead of above, but the following is eliminated
# during codegen
# with ib.if_scope(tx == 0):
# num_valid_boxes[i] = num_valid_boxes_local[0]

Expand Down

0 comments on commit 9b42008

Please sign in to comment.