From 0c21e36d58f81adeedec1749aeb04ed4e93a7f36 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 29 Dec 2020 04:32:18 +0900 Subject: [PATCH] minor improvement when topk is available --- python/tvm/topi/cuda/nms.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 1f229998b2a0..a6359a9c7867 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -560,6 +560,9 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): "int32", (1,), name="num_valid_boxes_local", scope="local" ) num_valid_boxes_local[0] = 0 + nkeep = if_then_else( + tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] + ) def nms_inner_loop(ib, j): # The box j is valid, invalidate other boxes that overlap with j above iou_threshold @@ -579,7 +582,7 @@ def nms_inner_loop(ib, j): num_valid_boxes_local[0] += 1 offset_j = j * 4 - num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx) + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) with ib.for_range(0, num_iter_per_thread) as _k: k = j + 1 + _k * nthread_tx + tx @@ -587,7 +590,7 @@ def nms_inner_loop(ib, j): with ib.if_scope( tvm.tir.all( - k < valid_count[i], + k < nkeep, out_scores[i, k] > 0, # is the box k still valid? tvm.tir.any( force_suppress > 0, @@ -615,10 +618,6 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - nkeep = if_then_else( - tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] - ) - with ib.for_range(0, nkeep) as j: # Proceed to the inner loop if the box j is still valid with ib.if_scope(out_scores[i, j] > -1.0):