diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index c71ea4d728a20..8a372f2cabd09 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -547,6 +547,9 @@ def nms_inner_loop(ib, j): box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 + if isinstance(max_output_size, int): + max_output_size = tvm.tir.const(max_output_size) + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: