diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 8946446f3cdc3..3887021cc24e5 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -21,7 +21,7 @@ from tvm import te from tvm.tir import if_then_else -from .sort import argsort, argsort_thrust +from .sort import argsort, argsort_thrust, is_thrust_available def cuda_atomic_add_rule(op): @@ -412,7 +412,9 @@ def nms_ir( sorted_index, valid_count, indices, - out, + out_bboxes, + out_scores, + out_class_ids, box_indices, num_valid_boxes, max_output_size, @@ -444,8 +446,14 @@ def nms_ir( dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. - out : Buffer - Output buffer, to be filled with sorted boxes. + out_bboxes : Buffer + Output buffer, to be filled with sorted box coordinates. + + out_scores : Buffer + Output buffer, to be filled with sorted scores. + + out_class_ids : Buffer + Output buffer, to be filled with sorted class ids. box_indices : Buffer A indices tensor mapping sorted indices to original indices @@ -532,9 +540,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) indices = ib.buffer_ptr(indices) - num_valid_boxes = ib.buffer_ptr(num_valid_boxes) - out = ib.buffer_ptr(out) + + # outputs + out_bboxes = ib.buffer_ptr(out_bboxes) + out_scores = ib.buffer_ptr(out_scores) + out_class_ids = ib.buffer_ptr(out_class_ids) box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -557,31 +569,53 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) i = by - base_idx = i * num_anchors * box_data_length + base_src_idx = i * num_anchors * box_data_length + base_bbox_idx = i * num_anchors * 4 + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output nkeep = if_then_else( tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] ) j = bx * max_threads + tx - with ib.if_scope(j < num_anchors): - box_indices[i * num_anchors + j] = -1 with ib.if_scope(j < nkeep): - # Fill in out with sorted boxes - with ib.for_range(0, box_data_length) as k: - out[(base_idx + j * box_data_length + k)] = data[ - (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) - ] + src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length + with ib.for_range(0, 4, for_type="unroll") as k: + out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k] + + out_scores[i * num_anchors + j] = data[src_idx + score_index] + + if id_index >= 0: + out_class_ids[i * num_anchors + j] = data[src_idx + id_index] + with ib.else_scope(): # Indices > nkeep are discarded + # Only needed for return_indices = False case + if return_indices is False: + with ib.if_scope(j < num_anchors): + with ib.for_range(0, 4, for_type="unroll") as k: + out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0 + + out_scores[i, j] = -1.0 + + if id_index >= 0: + out_class_ids[i, j] = -1.0 + + if return_indices: with ib.if_scope(j < num_anchors): - with ib.for_range(0, box_data_length) as k: - out[(base_idx + j * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): with ib.if_scope(j < valid_count[i]): - with ib.for_range(0, box_data_length) as k: - offset = base_idx + j * box_data_length + k - out[offset] = data[offset] + src_offset = base_src_idx + j * box_data_length + + with ib.for_range(0, 4, for_type="unroll") as k: + out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k] + out_scores[i * num_anchors + j] = data[src_offset + score_index] + + if id_index >= 0: + out_class_ids[i * num_anchors + j] = data[src_offset + id_index] + box_indices[i * num_anchors + j] = j with ib.new_scope(): @@ -595,7 +629,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): i = by - base_idx = i * num_anchors * box_data_length + base_bbox_idx = i * num_anchors * 4 num_valid_boxes_local = ib.allocate( "int32", (1,), name="num_valid_boxes_local", scope="local" ) @@ -613,37 +647,36 @@ def nms_inner_loop(ib, j): num_valid_boxes_local[0] += 1 - offset_j = j * box_data_length + offset_j = j * 4 num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) with ib.for_range(0, num_iter_per_thread) as _k: k = j + 1 + _k * nthread_tx + tx - offset_k = k * box_data_length + offset_k = k * 4 with ib.if_scope( tvm.tir.all( k < nkeep, - out[base_idx + offset_k + score_index] > 0, # is the box k still valid? + out_scores[i, k] > 0, # is the box k still valid? tvm.tir.any( force_suppress > 0, id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], + out_class_ids[i, k] == out_class_ids[i, j], ), ) ): iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, + out_bboxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, ) with ib.if_scope(iou >= iou_threshold): # invalidate the box k - out[base_idx + offset_k + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_k + id_index] = -1.0 + out_scores[i, k] = -1.0 + + if return_indices is False and id_index >= 0: + out_class_ids[i, k] = -1.0 - # Make sure to do the next loop in a lock step ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): @@ -653,9 +686,11 @@ def nms_inner_loop(ib, j): # Apply nms with ib.for_range(0, nkeep) as j: # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0): + with ib.if_scope(out_scores[i, j] > -1.0): with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we already reach max_output_size boxes + # No need to do more iteration if we have already reached max_output_size + # boxes + # TODO(masahi): Add TIR while loop to realize early exit from the outer loop with ib.if_scope(num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) with ib.else_scope(): @@ -699,6 +734,147 @@ def _fetch_score_ir(data, score, axis): return ib.get() +def _get_sorted_indices(data, data_buf, score_index, score_shape): + """Extract a 1D score tensor from the packed input and do argsort on it.""" + score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) + score_tensor = te.extern( + [score_shape], + [data], + lambda ins, outs: _fetch_score_ir( + ins[0], + outs[0], + score_index, + ), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[score_buf], + name="fetch_score", + tag="fetch_score", + ) + + if is_thrust_available(): + sort_tensor = argsort_thrust( + score_tensor, valid_count=None, axis=1, is_ascend=False, dtype="int32" + ) + else: + sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") + + return sort_tensor + + +def _run_nms( + data, + data_buf, + sort_tensor, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, +): + """Run NMS using sorted scores.""" + sort_tensor_buf = tvm.tir.decl_buffer( + sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 + ) + + valid_count_dtype = "int32" + valid_count_buf = tvm.tir.decl_buffer( + valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4 + ) + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) + + batch_size = data.shape[0] + num_anchors = data.shape[1] + + # output shapes + bbox_shape = (batch_size, num_anchors, 4) + score_shape = (batch_size, num_anchors) + class_id_shape = score_shape + box_indices_shape = score_shape + num_valid_boxes_shape = (batch_size, 1) + + return te.extern( + [bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape], + [data, sort_tensor, valid_count, indices], + lambda ins, outs: nms_ir( + ins[0], + ins[1], + ins[2], + ins[3], + outs[0], # sorted bbox + outs[1], # sorted scores + outs[2], # sorted class ids + outs[3], # box_indices + outs[4], # num_valid_boxes + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, + ), + dtype=[data.dtype, "float32", "float32", "int32", "int32"], + in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], + name="nms", + tag="nms", + ) + + +def _concatenate_outputs( + out_bboxes, out_scores, out_class_ids, out_shape, coord_start, score_index, id_index +): + """Pack the results from NMS into a single 5D or 6D tensor.""" + batch_size = out_bboxes.shape[0] + num_anchors = out_bboxes.shape[1] + + def ir(out_bboxes, out_scores, out_class_ids, out): + ib = tvm.tir.ir_builder.create() + + out_bboxes = ib.buffer_ptr(out_bboxes) + out_scores = ib.buffer_ptr(out_scores) + out_class_ids = ib.buffer_ptr(out_class_ids) + out = ib.buffer_ptr(out) + + with ib.if_scope(num_anchors > 0): + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", batch_size) + + tid = bx * nthread_tx + tx + i = by + + with ib.if_scope(tid < num_anchors): + with ib.for_range(0, 4, for_type="unroll") as j: + out[i, tid, coord_start + j] = out_bboxes[i, tid, j] + out[i, tid, score_index] = out_scores[i, tid] + if id_index >= 0: + out[i, tid, id_index] = out_class_ids[i, tid] + + return ib.get() + + return te.extern( + [out_shape], + [out_bboxes, out_scores, out_class_ids], + lambda ins, outs: ir(ins[0], ins[1], ins[2], outs[0]), + dtype=["float32"], + name="nms_output_concat", + tag="nms_output_concat", + ) + + def non_max_suppression( data, valid_count, @@ -790,77 +966,29 @@ def non_max_suppression( tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) f(tvm_data, tvm_valid_count, tvm_out) """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - - valid_count_dtype = "int32" - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4 - ) - score_axis = score_index - score_shape = (batch_size, num_anchors) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) - score_tensor = te.extern( - [score_shape], - [data], - lambda ins, outs: _fetch_score_ir( - ins[0], - outs[0], - score_axis, - ), - dtype=[data.dtype], - in_buffers=[data_buf], - out_buffers=[score_buf], - name="fetch_score", - tag="fetch_score", - ) - target = tvm.target.Target.current() - if ( - target - and target.kind.name == "cuda" - and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) - ): - sort_tensor = argsort_thrust( - score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype - ) - else: - sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype) - - sort_tensor_buf = tvm.tir.decl_buffer( - sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 - ) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) - out, box_indices, num_valid_boxes = te.extern( - [data.shape, score_shape, [batch_size, 1]], - [data, sort_tensor, valid_count, indices], - lambda ins, outs: nms_ir( - ins[0], - ins[1], - ins[2], - ins[3], - outs[0], - outs[1], - outs[2], - max_output_size, - iou_threshold, - force_suppress, - top_k, - coord_start, - id_index, - score_index, - return_indices, - ), - dtype=[data.dtype, "int32", "int32"], - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], - name="nms", - tag="nms", + sort_tensor = _get_sorted_indices(data, data_buf, score_index, (data.shape[0], data.shape[1])) + + out_bboxes, out_scores, out_class_ids, box_indices, num_valid_boxes = _run_nms( + data, + data_buf, + sort_tensor, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, ) if return_indices: return [box_indices, num_valid_boxes] - return out + return _concatenate_outputs( + out_bboxes, out_scores, out_class_ids, data.shape, coord_start, score_index, id_index + )