# NOTE(review): this is a unified diff against python/tvm/topi/cuda/nms.py, not
# Python source. Its whitespace has been collapsed (every hunk runs together on
# a few long lines), so the patch text is kept verbatim below. The original
# line breaks, indentation and blank context lines cannot be recovered from
# this text alone -- regenerate the patch from the commit before applying it.
#
# What the visible hunks show:
#   * Four small hunks (-500,7 / -532,7 / -600,6 / -615,6) each add or remove a
#     single line inside the IR-building code. The text shows a removed
#     "# Fill in out with sorted boxes" comment and an added
#     "# TODO(masahi): Add TIR while loop ..." comment; the remaining one-line
#     changes are ambiguous after the whitespace collapse (blank line vs. code).
#     The names after each "@@ ... @@" are hunk headings, presumably the nearest
#     preceding def, not necessarily the function being edited.
#   * Hunk -688,6 +688,75 adds two module-level helpers:
#       - _run_nms(...): declares sort_tensor_buf / valid_count_buf /
#         indices_buf with tvm.tir.decl_buffer, derives the output shapes from
#         data.shape[0] and data.shape[1], and returns the te.extern(...) call
#         that forwards to nms_ir. Five outputs are requested (bbox, scores,
#         class ids, box_indices, num_valid_boxes) with dtypes
#         [data.dtype, "float32", "float32", "int32", "int32"].
#       - _concatenate_outputs(...): a stub whose body is "# TODO" followed by
#         "return None".
#   * Hunk -779,63 +848,27 rewrites the tail of non_max_suppression: the inline
#     buffer declarations and te.extern call are removed and replaced by a call
#     to _get_sorted_indices followed by _run_nms. When return_indices is truthy
#     it still returns [box_indices, num_valid_boxes].
#
# NOTE(review): on the other path the old code returned out_bboxes (marked
# "TODO: do concat"); the new code returns _concatenate_outputs(...), which is
# the stub above. As written, non_max_suppression therefore returns None when
# return_indices is falsy -- confirm this is intentional work-in-progress.
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2280cff3059b..984eae6b48b9 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -500,7 +500,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): j = bx * max_threads + tx with ib.if_scope(j < nkeep): src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length - # Fill in out with sorted boxes with ib.for_range(0, 4) as k: out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k] @@ -532,7 +531,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.for_range(0, 4) as k: out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k] - out_scores[i * num_anchors + j] = data[src_offset + score_index] if id_index >= 0: @@ -600,6 +598,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(iou >= iou_threshold): # invalidate the box k out_scores[i, k] = -1.0 + if return_indices is False and id_index >= 0: out_class_ids[i, k] = -1.0 @@ -615,6 +614,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(out_scores[i, j] > -1.0): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes + # TODO(masahi): Add TIR while loop to realize early exit from the outer loop with ib.if_scope(num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) with ib.else_scope(): @@ -688,6 +688,75 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): return sort_tensor +def _run_nms( + data, + data_buf, + sort_tensor, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, +): + sort_tensor_buf = tvm.tir.decl_buffer( + sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 + ) + + valid_count_dtype = "int32" + valid_count_buf = tvm.tir.decl_buffer( + valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4 + ) + indices_buf = 
tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) + + batch_size = data.shape[0] + num_anchors = data.shape[1] + + # output shapes + bbox_shape = (batch_size, num_anchors, 4) + score_shape = (batch_size, num_anchors) + class_id_shape = score_shape + box_indices_shape = score_shape + num_valid_boxes_shape = (batch_size, 1) + + return te.extern( + [bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape], + [data, sort_tensor, valid_count, indices], + lambda ins, outs: nms_ir( + ins[0], + ins[1], + ins[2], + ins[3], + outs[0], # sorted bbox + outs[1], # sorted scores + outs[2], # sorted class ids + outs[3], # box_indices + outs[4], # num_valid_boxes + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, + ), + dtype=[data.dtype, "float32", "float32", "int32", "int32"], + in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], + name="nms", + tag="nms", + ) + + +def _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index): + # TODO + return None + + def non_max_suppression( data, valid_count, @@ -779,63 +848,27 @@ def non_max_suppression( tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) f(tvm_data, tvm_valid_count, tvm_out) """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - - valid_count_dtype = "int32" - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4 - ) - - score_shape = (batch_size, num_anchors) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - sort_tensor = _get_sorted_indices(data, data_buf, score_index, score_shape) - sort_tensor_buf = tvm.tir.decl_buffer( - sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 - ) - - indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) - - bbox_shape = (batch_size, 
num_anchors, 4) - class_id_shape = score_shape - box_indices_shape = score_shape - num_valid_boxes_shape = (batch_size, 1) - - out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = te.extern( - [bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape], - [data, sort_tensor, valid_count, indices], - lambda ins, outs: nms_ir( - ins[0], - ins[1], - ins[2], - ins[3], - outs[0], # sorted bbox - outs[1], # sorted scores - outs[2], # sorted class ids - outs[3], # box_indices - outs[4], # num_valid_boxes - max_output_size, - iou_threshold, - force_suppress, - top_k, - coord_start, - id_index, - score_index, - return_indices, - ), - dtype=[data.dtype, "float32", "float32", "int32", "int32"], - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], - name="nms", - tag="nms", + sort_tensor = _get_sorted_indices(data, data_buf, score_index, (data.shape[0], data.shape[1])) + + out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = _run_nms( + data, + data_buf, + sort_tensor, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + id_index, + score_index, + return_indices, ) if return_indices: return [box_indices, num_valid_boxes] - # TODO: do concat - return out_bboxes - # if id_index >= 0: - # return concatenate([out_bboxes - - # return out + return _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index)