From 3a273975b1456991fd3f70e055cd5f7c2cdd79fe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Dec 2020 21:22:16 +0900 Subject: [PATCH] slight change to initialization --- python/tvm/topi/cuda/nms.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index b8c76d5a880f..925d108e5abd 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -490,8 +490,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] ) j = bx * max_threads + tx - with ib.if_scope(j < num_anchors): - box_indices[i * num_anchors + j] = -1 with ib.if_scope(j < nkeep): # Fill in out with sorted boxes with ib.for_range(0, box_data_length) as k: @@ -500,9 +498,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ] with ib.else_scope(): # Indices > nkeep are discarded + # Only needed for return_indices = False case + if return_indices is False: + with ib.if_scope(j < num_anchors): + with ib.for_range(0, box_data_length) as k: + out[(base_idx + j * box_data_length + k)] = -1.0 + + if return_indices: with ib.if_scope(j < num_anchors): - with ib.for_range(0, box_data_length) as k: - out[(base_idx + j * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): with ib.if_scope(j < valid_count[i]): with ib.for_range(0, box_data_length) as k: