From 20b563031adf56f93a7bcfe5b853c477175f4f80 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 26 Dec 2020 10:06:43 +0900 Subject: [PATCH] use one block to avoid global sync issue --- python/tvm/topi/cuda/nms.py | 64 ++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 56f63ce021c3..65f7e3950e1c 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -513,17 +513,14 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_by = batch_size nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) by = te.thread_axis("blockIdx.y") tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) i = by - k = bx * nthread_tx + tx + base_idx = i * num_anchors * box_data_length num_valid_boxes_local = ib.allocate( "int32", (1,), name="num_valid_boxes_local", scope="local" @@ -531,43 +528,49 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, j): - # box j is valid, invalidate other boxes that overlap with j above iou_threshold + # the box j is valid, invalidate other boxes that overlap with j above iou_threshold # When return_indices is False, no need to populate box_indices if return_indices: # Only one thread needs to this write - with ib.if_scope(k == 0): + with ib.if_scope(tx == 0): orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 offset_j = j * box_data_length - offset_k = k * box_data_length + num_iter_per_thread = ceil_div(num_anchors - (j + 1), nthread_tx) - with ib.if_scope( - tvm.tir.all( - j < k, - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, 
out[base_idx + offset_k + id_index] >= 0), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_k + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_k + id_index] = -1.0 + with ib.for_range(0, num_iter_per_thread) as _k: + k = j + 1 + _k * nthread_tx + tx + offset_k = k * box_data_length + + with ib.if_scope( + tvm.tir.all( + k < num_anchors, + out[base_idx + offset_k + score_index] > 0, # is the box k still valid? + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), + ) + ): + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out[base_idx + offset_k + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_k + id_index] = -1.0 - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size) @@ -590,7 +593,8 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.if_scope(tx == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0