diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2dafea0de5360..77795e8afdd37 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -118,9 +118,89 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
     return ib.get()
 
 
-def get_valid_counts_ir(
-    data, valid_count, out, out_indices, score_threshold, id_index, score_index
-):
+def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+    if isinstance(score_threshold, float):
+        score_threshold = tvm.tir.FloatImm("float32", score_threshold)
+    id_index = tvm.tir.IntImm("int32", id_index)
+    score_index = tvm.tir.IntImm("int32", score_index)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = num_anchors // max_threads + 1
+        nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            score = data[(i * num_anchors + j) * elem_length + score_index]
+            with ib.if_scope(
+                tvm.tir.all(
+                    score > score_threshold,
+                    tvm.tir.any(
+                        id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
+                    ),
+                )
+            ):
+                valid_boxes[i * num_anchors + j] = 1
+            with ib.else_scope():
+                valid_boxes[i * num_anchors + j] = 0
+    return ib.get()
+
+
+def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
+    batch_size = valid_boxes.shape[0]
+    num_anchors = valid_boxes.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+
+    valid_count = ib.buffer_ptr(valid_count)
+    valid_indices = ib.buffer_ptr(valid_indices)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = batch_size // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        # TODO(mbrookhart): Parallelize the sum and cumsum here
+        current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
+        with ib.if_scope(tid < batch_size):
+            current_index[0] = 0
+            valid_count[tid] = 0
+            with ib.for_range(0, num_anchors) as j:
+                idx = tid * num_anchors + j
+                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
+                with ib.if_scope(valid_boxes[idx] == 1):
+                    valid_indices[idx] = current_index[0]
+                    current_index[0] = current_index[0] + 1
+                with ib.else_scope():
+                    valid_indices[idx] = -1
+    return ib.get()
+
+
+def get_valid_counts_ir(data, valid_indices, out, out_indices):
     """Low level IR to get valid count of bounding boxes given a score threshold.
     Also prepares to move valid boxes to the top of input data.
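Note: the new get_valid_boxes_ir is a fully parallel flag pass (one thread per (batch, anchor) pair), and get_valid_indices_ir turns those flags into a per-batch count plus a compaction index with a sequential scan, which the TODO above proposes to parallelize later. A minimal NumPy sketch of the same logic, for reference only (the reference_* helper names are illustrative, not part of the patch):

import numpy as np

def reference_valid_boxes(data, score_threshold, id_index, score_index):
    # One independent keep/drop decision per (batch, anchor) pair,
    # matching the one-thread-per-anchor layout of get_valid_boxes_ir.
    flags = data[:, :, score_index] > score_threshold
    if id_index >= 0:
        flags &= data[:, :, id_index] >= 0
    return flags.astype(np.int32)

def reference_valid_indices(flags):
    # Per-batch sum and scan over the flags: counts[i] is the number of
    # kept boxes in batch i; indices[i, j] is box j's output slot, or -1
    # when the box was filtered out (the inclusive cumsum minus one gives
    # the 0-based slot at every kept position).
    counts = flags.sum(axis=1).astype(np.int32)
    indices = np.cumsum(flags, axis=1, dtype=np.int32) - 1
    indices[flags == 0] = -1
    return counts, indices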
@@ -158,47 +238,51 @@ def get_valid_counts_ir(
     data = ib.buffer_ptr(data)
-    valid_count = ib.buffer_ptr(valid_count)
+    valid_indices = ib.buffer_ptr(valid_indices)
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
     one = tvm.tir.const(1, dtype=out.dtype)
-    if isinstance(score_threshold, float):
-        score_threshold = tvm.tir.FloatImm("float32", score_threshold)
-    id_index = tvm.tir.IntImm("int32", id_index)
-    score_index = tvm.tir.IntImm("int32", score_index)
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = num_anchors // max_threads + 1
+    nthread_by = batch_size
+    nthread_bz = elem_length
     with ib.new_scope():
-        nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
         tid = bx * max_threads + tx
-        with ib.if_scope(tid < batch_size):
-            valid_count[tid] = 0
-            i = tid
-            with ib.for_range(0, num_anchors) as j:
-                score = data[(i * num_anchors + j) * elem_length + score_index]
-                with ib.if_scope(
-                    tvm.tir.all(
-                        score > score_threshold,
-                        tvm.tir.any(
-                            id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
-                        ),
-                    )
-                ):
-                    with ib.for_range(0, elem_length) as k:
-                        out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[
-                            (i * num_anchors + j) * elem_length + k
-                        ]
-                    out_indices[i * num_anchors + valid_count[i]] = j
-                    valid_count[i] += 1
-                with ib.if_scope(j >= valid_count[i]):
-                    with ib.for_range(0, elem_length) as k:
-                        out[(i * num_anchors + j) * elem_length + k] = -one
-                    out_indices[i * num_anchors + j] = -1
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            out[(i * num_anchors + j) * elem_length + k] = -one
+            out_indices[i * num_anchors + j] = -1
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            with ib.if_scope(valid_indices[i, tid] >= 0):
+                out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
+                    (i * num_anchors + j) * elem_length + k
+                ]
+                out_indices[i * num_anchors + valid_indices[i, tid]] = j
     return ib.get()
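With the indices precomputed, the rewritten get_valid_counts_ir becomes a pure init-plus-scatter over a 3D launch grid (threads over anchors, blockIdx.y over batches, blockIdx.z over box elements), with no cross-anchor dependence left. Roughly equivalent NumPy, again only a sketch (reference_compact is a made-up name):

def reference_compact(data, valid_indices):
    # First scope: fill both outputs with the sentinel values (-one / -1).
    out = np.full_like(data, -1)
    out_indices = np.full(valid_indices.shape, -1, dtype=np.int32)
    # Second scope: scatter every surviving box into its precomputed slot.
    for i in range(valid_indices.shape[0]):
        for j in range(valid_indices.shape[1]):
            if valid_indices[i, j] >= 0:
                out[i, valid_indices[i, j]] = data[i, j]
                out_indices[i, valid_indices[i, j]] = j
    return out, out_indices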
@@ -231,23 +315,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    valid_boxes_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8
+    )
+    valid_boxes = te.extern(
+        [(batch_size, num_anchors)],
+        [data],
+        lambda ins, outs: get_valid_boxes_ir(
+            ins[0], outs[0], score_threshold, id_index, score_index
+        ),
+        dtype=["int32"],
+        in_buffers=[data_buf],
+        out_buffers=[valid_boxes_buf],
+        name="get_valid_boxes",
+        tag="get_valid_boxes_gpu",
+    )
+
+    valid_indices_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
+    )
     valid_count_buf = tvm.tir.decl_buffer(
         (batch_size,), "int32", "valid_count_buf", data_alignment=8
     )
+    valid_count, valid_indices = te.extern(
+        [(batch_size,), (batch_size, num_anchors)],
+        [valid_boxes],
+        lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
+        dtype=["int32"],
+        in_buffers=[valid_boxes_buf],
+        out_buffers=[valid_count_buf, valid_indices_buf],
+        name="get_valid_indices",
+        tag="get_valid_indices_gpu",
+    )
+
     out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
     out_indices_buf = tvm.tir.decl_buffer(
         (batch_size, num_anchors), "int32", "out_buf", data_alignment=8
     )
-    valid_count, out, out_indices = te.extern(
-        [(batch_size,), data.shape, (batch_size, num_anchors)],
-        [data],
-        lambda ins, outs: get_valid_counts_ir(
-            ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index
-        ),
+    out, out_indices = te.extern(
+        [data.shape, (batch_size, num_anchors)],
+        [data, valid_indices],
+        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
         dtype=["int32", data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[valid_count_buf, out_buf, out_indices_buf],
+        in_buffers=[data_buf, valid_indices_buf],
+        out_buffers=[out_buf, out_indices_buf],
         name="get_valid_counts",
         tag="get_valid_counts_gpu",
     )
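Since get_valid_counts is now assembled from three chained te.extern stages, the lowered module contains three kernels where there used to be one. A quick smoke-test sketch to eyeball the result (the shapes and the schedule_get_valid_counts call are my assumption, not taken from this patch; the IR builders call tvm.target.Target.current(), so everything must run under a CUDA target context):

import tvm
from tvm import te, topi

data = te.placeholder((1, 2500, 6), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    # get_valid_counts returns [valid_count, out, out_indices].
    valid_count, out, out_indices = topi.cuda.get_valid_counts(
        data, score_threshold=0.0, id_index=0, score_index=1
    )
    s = topi.cuda.schedule_get_valid_counts([valid_count, out, out_indices])
    # Lowering (rather than building) avoids needing a physical GPU.
    print(tvm.lower(s, [data, valid_count, out, out_indices], simple_mode=True))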