diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2dc177a0fae8..8946446f3cdc 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -151,98 +151,104 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_indices = ib.buffer_ptr(valid_indices)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-
-    # Copy boxes to valid_indices
-    with ib.new_scope():
-        nthread_tx = max_threads
-        nthread_bx = ceil_div(num_anchors, max_threads)
-        nthread_by = batch_size
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        by = te.thread_axis("blockIdx.y")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        tid = bx * nthread_tx + tx
-        with ib.if_scope(tid < num_anchors):
-            valid_indices[by, tid] = valid_boxes[by, tid]
-
-    nthread_tx = max_threads
-    nthread_bx = ceil_div(num_anchors, max_threads)
-    nthread_by = batch_size
-
-    ## The following algorithm performs parallel exclusive scan to get
-    ## a tensor that can later be used to select valid indices
-    # Up Sweep of exclusive scan
-    lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
-    )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
-
+    with ib.if_scope(num_anchors > 0):
+        # Copy boxes to valid_indices
         with ib.new_scope():
+            nthread_tx = max_threads
+            nthread_bx = ceil_div(num_anchors, max_threads)
+            nthread_by = batch_size
             tx = te.thread_axis("threadIdx.x")
             bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
-            )
-            tid = bx * nthread_tx + tx
-
             by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
             ib.scope_attr(by, "thread_extent", nthread_by)
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
-            start[0] = width * tid
-            with ib.if_scope(start[0] < num_anchors):
-                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                end[0] = tvm.te.min(start[0] + width, num_anchors)
-                with ib.if_scope(middle[0] < num_anchors):
-                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
-                        by * num_anchors + middle[0] - 1
-                    ]
-
-    # Down Sweep of exclusive scan
-    with ib.new_scope():
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", batch_size)
-        with ib.if_scope(bx < batch_size):
-            valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
-            valid_indices[(bx + 1) * num_anchors - 1] = 0
+            tid = bx * nthread_tx + tx
+            with ib.if_scope(tid < num_anchors):
+                valid_indices[by, tid] = valid_boxes[by, tid]
 
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << (lim - l2_width - 1)
-
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
+        ## The following algorithm performs parallel exclusive scan to get
+        ## a tensor that can later be used to select valid indices
+        # Up Sweep of exclusive scan
+        lim = tvm.tir.generic.cast(
+            tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+        )
+        with ib.for_range(0, lim, dtype="int64") as l2_width:
+            width = 2 << l2_width
+
+            with ib.new_scope():
+                tx = te.thread_axis("threadIdx.x")
+                bx = te.thread_axis("blockIdx.x")
+                ib.scope_attr(tx, "thread_extent", nthread_tx)
+                ib.scope_attr(
+                    bx,
+                    "thread_extent",
+                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+                )
+                tid = bx * nthread_tx + tx
+
+                by = te.thread_axis("blockIdx.y")
+                ib.scope_attr(by, "thread_extent", nthread_by)
+                start = ib.allocate("int64", (1,), name="start", scope="local")
+                middle = ib.allocate("int64", (1,), name="middle", scope="local")
+                end = ib.allocate("int64", (1,), name="end", scope="local")
+                start[0] = width * tid
+                with ib.if_scope(start[0] < num_anchors):
+                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                    end[0] = tvm.te.min(start[0] + width, num_anchors)
+                    with ib.if_scope(middle[0] < num_anchors):
+                        valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                            by * num_anchors + middle[0] - 1
+                        ]
+
+        # Down Sweep of exclusive scan
         with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
             bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
-            )
-            tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
-            tmp = ib.allocate("int32", (1,), name="end", scope="local")
-            start[0] = width * tid
-            with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
-                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                end[0] = tvm.tir.min(start[0] + width, num_anchors)
-                with ib.if_scope(middle[0] < num_anchors):
-                    tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
-                    valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
-                        by * num_anchors + end[0] - 1
-                    ]
-                    valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
+            ib.scope_attr(bx, "thread_extent", batch_size)
+            with ib.if_scope(bx < batch_size):
+                valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
+                valid_indices[(bx + 1) * num_anchors - 1] = 0
+
+        with ib.for_range(0, lim, dtype="int64") as l2_width:
+            width = 2 << (lim - l2_width - 1)
+
+            with ib.new_scope():
+                tx = te.thread_axis("threadIdx.x")
+                bx = te.thread_axis("blockIdx.x")
+                ib.scope_attr(tx, "thread_extent", nthread_tx)
+                ib.scope_attr(
+                    bx,
+                    "thread_extent",
+                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+                )
+                tid = bx * nthread_tx + tx
+
+                by = te.thread_axis("blockIdx.y")
+                ib.scope_attr(by, "thread_extent", nthread_by)
+                start = ib.allocate("int64", (1,), name="start", scope="local")
+                middle = ib.allocate("int64", (1,), name="middle", scope="local")
+                end = ib.allocate("int64", (1,), name="end", scope="local")
+                tmp = ib.allocate("int32", (1,), name="end", scope="local")
+                start[0] = width * tid
+                with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
+                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                    end[0] = tvm.tir.min(start[0] + width, num_anchors)
+                    with ib.if_scope(middle[0] < num_anchors):
+                        tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
+                        valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
+                            by * num_anchors + end[0] - 1
+                        ]
+                        valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
+    with ib.else_scope():
+        with ib.new_scope():
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(bx, "thread_extent", batch_size)
+            with ib.if_scope(bx < batch_size):
+                valid_count[bx] = 0
 
     return ib.get()
 
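
Note (not part of the patch): the guarded kernel body above is a work-efficient (Blelloch-style) exclusive scan over each batch row's validity flags, with the new num_anchors > 0 branch falling back to valid_count[bx] = 0. The following is a minimal host-side Python/NumPy sketch of that per-row scan, shown only to illustrate the up-sweep/down-sweep structure and the zero-anchor guard; the helper name exclusive_scan_row and the sample flags are hypothetical.

# Illustrative sketch only: per-row exclusive scan as the TIR kernel computes it.
import math

import numpy as np


def exclusive_scan_row(flags):
    """Return (exclusive prefix sum, row total) for one row of 0/1 validity flags."""
    num_anchors = len(flags)
    if num_anchors == 0:
        # Mirrors the new else branch: with no anchors, the row's valid count is 0.
        return np.array([], dtype=np.int64), 0

    data = np.array(flags, dtype=np.int64)
    lim = math.ceil(math.log2(num_anchors))

    # Up sweep: accumulate partial sums at the end of progressively wider segments.
    for l2_width in range(lim):
        width = 2 << l2_width
        for start in range(0, num_anchors, width):
            middle = start + width // 2
            end = min(start + width, num_anchors)
            if middle < num_anchors:
                data[end - 1] += data[middle - 1]

    total = int(data[num_anchors - 1])  # corresponds to valid_count[bx]
    data[num_anchors - 1] = 0

    # Down sweep: push the partial sums back down to form an exclusive scan.
    for l2_width in range(lim):
        width = 2 << (lim - l2_width - 1)
        for start in range(0, num_anchors, width):
            middle = start + width // 2
            end = min(start + width, num_anchors)
            if middle < num_anchors:
                tmp = int(data[middle - 1])
                data[middle - 1] = data[end - 1]
                data[end - 1] += tmp

    return data, total


if __name__ == "__main__":
    flags = [1, 0, 1, 1, 0, 1, 0]
    scan, count = exclusive_scan_row(flags)
    assert count == sum(flags)
    assert np.array_equal(scan, np.cumsum([0] + flags[:-1]))
    print(scan, count)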