From 9b42008b42004f5f05cdaa51e2f6feeadf99abb1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Dec 2020 19:50:36 +0900 Subject: [PATCH] add some comments, remove the check on negative class id --- python/tvm/topi/cuda/nms.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 916c2c758d44..b8c76d5a880f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -528,14 +528,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, j): - # the box j is valid, invalidate other boxes that overlap with j above iou_threshold + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold # When return_indices is False, no need to populate box_indices if return_indices: orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # TODO(masahi): Want to do this instead of above, but the following is eliminated + # during codegen # # Only one thread needs to this write # with ib.if_scope(tx == 0): # orig_idx = sorted_index[i * num_anchors + j] @@ -554,7 +555,6 @@ def nms_inner_loop(ib, j): tvm.tir.all( k < num_anchors, out[base_idx + offset_k + score_index] > 0, # is the box k still valid? 
- tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), tvm.tir.any( force_suppress > 0, id_index < 0, @@ -574,6 +574,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(id_index >= 0): out[base_idx + offset_k + id_index] = -1.0 + # Make sure to do the next loop in a lock step ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): @@ -582,14 +583,8 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: - with ib.if_scope( - tvm.tir.all( - out[base_idx + (j * box_data_length) + score_index] > -1.0, - tvm.tir.any( - id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0 - ), - ) - ): + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes with ib.if_scope(num_valid_boxes_local[0] < max_output_size): @@ -598,7 +593,8 @@ def nms_inner_loop(ib, j): nms_inner_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] - # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # TODO(masahi): Want to do this instead of above, but the following is eliminated + # during codegen # with ib.if_scope(tx == 0): # num_valid_boxes[i] = num_valid_boxes_local[0]