diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 005b900d5d44..976304e79c34 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -114,11 +114,19 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + std::string output_format; + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, - "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} + "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(output_format) + .set_default("onnx") + .describe( + "Output format, onnx or tensorflow. Returns outputs in a way that can be easily " + "consumed by each frontend."); + } }; /*! \brief Attributes used in roi_align operators */ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 040f8384dbe0..fdd58bb53ba5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -793,6 +793,88 @@ def _impl(inputs, attr, params, mod): return _impl +def convert_combined_nms_with_all_class_nms( + batch_size, + max_output_boxes_per_batch, + num_class, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + clip_boxes, +): + """Converts TF combined_nms using Relay all_class_max_suppression op""" + (selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="tensorflow", + ) + box_range = _op.arange( + _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" + ) + assert isinstance(batch_size, int), "dynamic batch size not supported yet." + tile_batch_reps = _op.const([batch_size, 1]) + box_range_2d = _op.tile(box_range, tile_batch_reps) + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) + + def select_topk(do_zero_pad): + def true_branch(): + arange = _op.arange( + _op.const(0, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), + dtype="int64", + ) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) + ) + topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) + nmsed_scores = _op.gather(selected_scores, 1, topk_indices) + nmsed_scores = nmsed_scores * valid_mask + return nmsed_scores, topk_indices + + def false_branch(): + if isinstance(max_output_boxes_per_class, int): + # Do topk on smaller input if possible + slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") + selected_scores_slice = _op.strided_slice( + selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] + ) + else: + selected_scores_slice = selected_scores + return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both") + + # TODO(masahi): support dynamic num_boxes + # return _expr.If(do_zero_pad, true_branch(), false_branch()) + return true_branch() if do_zero_pad else false_branch() + + assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet." 
+ nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) + + indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) + nmsed_classes = _op.take(indices, _op.const(0), axis=2) + nmsed_classes = _op.cast(nmsed_classes, "float32") + nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) + num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64")) + + if clip_boxes: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 + ) + + def _combined_nms(): def _impl(inputs, attr, params, mod): # Get parameter values @@ -821,9 +903,27 @@ def _impl(inputs, attr, params, mod): q = boxes_shape[2] num_classes = scores_shape[2] - if q != num_classes: - # When q is 1, it means same box coords are used for all classes. - boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4)) + assert isinstance(batch_size, int) and isinstance( + num_anchors, int + ), "Dynamic inputs not supported yet" + + if q == 1: + boxes = _op.squeeze(boxes, axis=[2]) + scores_trans = _op.transpose(scores, [0, 2, 1]) + max_output_boxes_per_batch = num_anchors * num_classes + return convert_combined_nms_with_all_class_nms( + batch_size, + max_output_boxes_per_batch, + num_classes, + boxes, + scores_trans, + max_output_size, + iou_threshold, + score_threshold, + max_total_size.data.numpy().item(), + attr["clip_boxes"], + ) + boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4]) scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 0d6c3ef58cdf..d56820e409aa 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1095,7 +1095,15 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + output_format = attrs.output_format + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold, + output_format, + ) return _compute_nms diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8d6abf1a8c20..cab9f703e88a 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -89,7 +89,7 @@ def nms_shape_func(attrs, inputs, _): @script -def _all_class_nms_shape_func(boxes_shape, scores_shape): +def _all_class_nms_shape_func_onnx(boxes_shape, scores_shape): out_shape = output_tensor((2,), "int64") count_shape = output_tensor((1,), "int64") @@ -99,9 +99,27 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape): return out_shape, count_shape +@script +def _all_class_nms_shape_func_tf(boxes_shape, scores_shape): + out_indices_shape = output_tensor((3,), "int64") + out_scores_shape = output_tensor((2,), "int64") + count_shape = output_tensor((1,), "int64") + + out_indices_shape[0] = boxes_shape[0] + out_indices_shape[1] = scores_shape[1] * boxes_shape[1] + out_indices_shape[2] = int64(2) + out_scores_shape[0] = boxes_shape[0] + 
out_scores_shape[1] = scores_shape[1] * boxes_shape[1]
+    count_shape[0] = boxes_shape[0]
+
+    return out_indices_shape, out_scores_shape, count_shape
+
+
@reg.register_shape_func("vision.all_class_non_max_suppression", False)
def all_class_nms_shape_func(attrs, inputs, _):
-    return _all_class_nms_shape_func(inputs[0], inputs[1])
+    if attrs.output_format == "onnx":
+        return _all_class_nms_shape_func_onnx(inputs[0], inputs[1])
+    return _all_class_nms_shape_func_tf(inputs[0], inputs[1])
@script
diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
index 3f829e0b1cc7..8c54075d952c 100644
--- a/python/tvm/relay/op/vision/nms.py
+++ b/python/tvm/relay/op/vision/nms.py
@@ -152,7 +152,12 @@ def non_max_suppression(
def all_class_non_max_suppression(
-    boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
+    boxes,
+    scores,
+    max_output_boxes_per_class=-1,
+    iou_threshold=-1.0,
+    score_threshold=-1.0,
+    output_format="onnx",
):
    """Non-maximum suppression operator for object detection, corresponding to ONNX
    NonMaxSuppression and TensorFlow combined_non_max_suppression.
@@ -175,16 +180,31 @@ def all_class_non_max_suppression(
    score_threshold : float or relay.Expr, optional
        Score threshold to filter out low score boxes early
+    output_format : string, optional
+        "onnx" or "tensorflow". Specifies which frontend the outputs are
+        intended to be consumed by.
+
    Returns
    -------
    out : relay.Tuple
-        The output is a relay.Tuple of two tensors, the first is `indices` of size
-        `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor
-        `num_total_detection` of shape `(1,)` representing the total number of selected boxes.
+        If `output_format` is "onnx", the output is a relay.Tuple of two tensors, the first is
+        `indices` of size `(batch_size * num_class * num_boxes, 3)` and the second is a scalar
+        tensor `num_total_detection` of shape `(1,)` representing the total number of selected
+        boxes. The three values in `indices` encode batch, class, and box indices.
        Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
        first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
        `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
        rows are valid.
+
+        If `output_format` is "tensorflow", the output is a relay.Tuple of three tensors. The
+        first is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is
+        `scores` of size `(batch_size, num_class * num_boxes)`, and the third is
+        `num_total_detection` of size `(batch_size,)` representing the total number of selected
+        boxes per batch. The two values in `indices` encode class and box indices. Of the
+        `num_class * num_boxes` rows of `indices` at batch b, only the first
+        `num_total_detection[b]` entries are valid. The second axis of `indices` and `scores`
+        is sorted within each class by box scores, but not across classes. So the box indices
+        and scores for class 0 come first, in sorted order, followed by those for class 1, etc.
""" if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") @@ -194,6 +214,15 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format, ) - return expr.TupleWrapper(out, 2) + + if output_format == "onnx": + return expr.TupleWrapper(out, 2) + + return expr.TupleWrapper(out, 3) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 9a3b86d72b18..e402c5888978 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -32,6 +32,7 @@ calculate_overlap, binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -988,8 +989,74 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_size * num_class + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + zero = cast(0, "int64") + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + + with ib.if_scope(idx < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + idx + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") + collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + idx + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. 
@@ -1012,16 +1079,30 @@ def all_class_non_max_suppression(
    score_threshold : float or tvm.te.Tensor, optional
        Score threshold to filter out low score boxes early
+    output_format : str, optional
+        "onnx" or "tensorflow", see below.
+
    Returns
    -------
-    out : [tvm.te.Tensor, tvm.te.Tensor]
-        The output is two tensors, the first is `indices` of size
+    out : list of tvm.te.Tensor
+        If `output_format` is "onnx", the output is two tensors. The first is `indices` of size
        `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor
        `num_total_detection` of shape `(1,)` representing the total number of selected
-        boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
+        boxes. The three values in `indices` encode batch, class, and box indices.
+        Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
        first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
        `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
        rows are valid.
+
+        If `output_format` is "tensorflow", the output is three tensors. The first
+        is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of
+        size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of
+        size `(batch_size,)` representing the total number of selected boxes per batch. The two
+        values in `indices` encode class and box indices. Of the `num_class * num_boxes` rows of
+        `indices` at batch b, only the first `num_total_detection[b]` entries are valid. The
+        second axis of `indices` and `scores` is sorted within each class by box scores, but not
+        across classes. So the box indices and scores for class 0 come first, in sorted order,
+        followed by those for class 1, etc.
""" batch, num_class, num_boxes = scores.shape @@ -1029,7 +1110,7 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -1037,14 +1118,30 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=(output_format == "tensorflow"), ) + if output_format == "onnx": + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" + num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, ) - return [selected_indices, num_total_detections] + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 6dbaf02191c8..0d19a92f2058 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -231,7 +231,7 @@ def ir(data, data_ex_scan, reduction): data[tid * scan_axis_size + scan_axis_size - 1], ) with ib.else_scope(): - reduction[tid] = 0 + reduction[tid] = cast(0, reduction.dtype) return ib.get() diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 88983ab89f76..5208aeccd413 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -39,7 +39,9 @@ def traverse(op): traverse(tensor.op) scheduled_ops.append(op) - traverse(outs[0].op) + for o in outs: + traverse(o.op) + return s diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 744c5ef7feda..7a51946a279a 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -22,14 +22,15 @@ from tvm.te import hybrid from tvm.tir import if_then_else -from ..sort import sort, argsort +from ..sort import argsort from ..math import cast -from ..transform import reshape +from ..transform import reshape, gather from .. 
import reduction from ..scan import cumsum from .nms_util import ( binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -727,8 +728,62 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + zero = cast(0, "int64") + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + with ib.if_scope(j < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -751,25 +806,40 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. + Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is `indices` of size + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected - boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. 
+
+        If `output_format` is "tensorflow", the output is three tensors. The first
+        is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of
+        size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of
+        size `(batch_size,)` representing the total number of selected boxes per batch. The two
+        values in `indices` encode class and box indices. Of the `num_class * num_boxes` rows of
+        `indices` at batch b, only the first `num_total_detection[b]` entries are valid. The
+        second axis of `indices` and `scores` is sorted within each class by box scores, but not
+        across classes. So the box indices and scores for class 0 come first, in sorted order,
+        followed by those for class 1, etc.
    """
    batch, num_class, num_boxes = scores.shape
    scores = reshape(scores, (batch * num_class, num_boxes))
-    sorted_scores = sort(scores, axis=1, is_ascend=False)
    sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
+    sorted_scores = gather(scores, 1, sorted_indices)
+
    valid_count = _get_valid_box_count(sorted_scores, score_threshold)
-    selected_indices, num_detections = run_all_class_nms(
+    selected_indices, selected_scores, num_detections = run_all_class_nms(
        boxes,
        sorted_scores,
        sorted_indices,
@@ -777,14 +847,29 @@ def all_class_non_max_suppression(
        max_output_boxes_per_class,
        iou_threshold,
        _nms_loop,
+        return_scores=(output_format == "tensorflow"),
    )
-    row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
+    if output_format == "onnx":
+        row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
+        num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
-    num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
-
-    selected_indices = collect_selected_indices(
-        num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
+        selected_indices = collect_selected_indices(
+            num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
+        )
+        return [selected_indices, num_total_detections]
+
+    num_detections_per_batch = reshape(num_detections, (batch, num_class))
+    row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
+    num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)
+
+    selected_indices, selected_scores = collect_selected_indices_and_scores(
+        selected_indices,
+        selected_scores,
+        num_detections_per_batch,
+        row_offsets,
+        num_total_detections,
+        _collect_selected_indices_and_scores_ir,
    )
-    return [selected_indices, num_total_detections]
+    return [selected_indices, selected_scores, num_total_detections]
diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py
index 1147b1687783..d12592fd111a 100644
--- a/python/tvm/topi/vision/nms_util.py
+++ b/python/tvm/topi/vision/nms_util.py
@@ -106,28 +106,63 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of
        first, in descending of scores, followed by boxes from batch 0, class 1 etc.
""" batch_class, num_boxes = selected_indices.shape - - selected_indices_buf = tvm.tir.decl_buffer( - selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 - ) - num_detections_buf = tvm.tir.decl_buffer( - num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 - ) - row_offsets_buf = tvm.tir.decl_buffer( - row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 - ) - return te.extern( [(batch_class * num_boxes, 3)], [selected_indices, num_detections, row_offsets], lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]), dtype=["int64"], - in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], name="collect_indices", tag="collect_indices", ) +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir +): + """Collect selected indices and scores from the core NMS loop into one linear output + + Parameters + ---------- + num_class : int + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores + of selected boxes by the core NMS loop. + + num_detections tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), representing + the number of boxes selected by the core NMS loop, per batch and class + + row_offsets tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan + of num_detections along axis 1 + + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors. The first is indices of size + (batch_size, num_class* num_boxes, 2), and the second is scores of size + (batch_size, num_class* num_boxes). 
+ """ + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), + dtype=["int64", "float32"], + name="collect_indices_and_scores", + tag="collect_indices_and_scores", + ) + + def _all_class_nms_ir( boxes, sorted_scores, @@ -139,6 +174,7 @@ def _all_class_nms_ir( iou_threshold, max_output_size_per_class, box_indices, + selected_scores, num_valid_boxes, nms_loop, ): @@ -150,6 +186,9 @@ def _all_class_nms_ir( box_indices = ib.buffer_ptr(box_indices) num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -171,6 +210,9 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + def on_new_invalidated_box(*_): pass @@ -201,6 +243,7 @@ def run_all_class_nms( max_output_size_per_class, iou_threshold, nms_loop, + return_scores=False, ): """The core all class NMS routine @@ -230,31 +273,49 @@ def run_all_class_nms( nms_loop : function A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + return_scores : bool, optional + Whether or not to return selected scores, needed by the tensorflow output format. + Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is indices of size - (batch_size * num_class, num_boxes) and the second is a tensor + out : a list of tvm.te.Tensor + The output is three tensors, the first and second are indices and scores of size + (batch_size * num_class, num_boxes), and the third is a tensor num_selected_boxes of shape (batch_size * num_class,) representing the total number of - selected boxes per batch and class. + selected boxes per batch and class. If return_scores is False, the second output is + None. 
""" batch, num_boxes, _ = boxes.shape batch_class = sorted_scores.shape[0] num_class = batch_class // batch - boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) - sorted_scores_buf = tvm.tir.decl_buffer( - sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 - ) - sorted_indices_buf = tvm.tir.decl_buffer( - sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, "int32", "valid_count_buf", data_alignment=4 - ) + if return_scores is False: + selected_indices, num_detections = te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ), + dtype=["int32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) + return selected_indices, None, num_detections return te.extern( - [(batch_class, num_boxes), (1, batch_class)], + [(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( ins[0], # boxes @@ -267,16 +328,11 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - outs[1], # num_selected_boxes + outs[1], # selected scores + outs[2], # num_selected_boxes nms_loop, ), - dtype=["int32", "int32"], - in_buffers=[ - boxes_buf, - sorted_scores_buf, - sorted_indices_buf, - valid_count_buf, - ], + dtype=["int32", "float32", "int32"], name="all_class_nms", tag="all_class_nms", ) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 53cd71745d5b..8c33c1648cf3 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -152,24 +152,39 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs IndexExpr num_classes = scores_shape[1]; IndexExpr num_boxes = boxes_shape[1]; - IndexExpr num_total_boxes = Any(); - if (!batch.as() && !num_boxes.as()) { - num_total_boxes = batch * num_classes * num_boxes; - } + const auto* param = attrs.as(); + CHECK(param); - // assign output type std::vector fields; - std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(64))); + if (param->output_format == "onnx") { + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } + std::vector oshape{num_total_boxes, 3}; + std::vector counts_shape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } else { + IndexExpr num_total_boxes_per_batch = Any(); + if (!num_boxes.as()) { + num_total_boxes_per_batch = num_classes * num_boxes; + } + std::vector indices_shape{batch, num_total_boxes_per_batch, 2}; + std::vector scores_shape{batch, num_total_boxes_per_batch}; + std::vector counts_shape{batch}; + fields.push_back(TensorType(indices_shape, DataType::Int(64))); + fields.push_back(TensorType(scores_shape, DataType::Float(32))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } reporter->Assign(types[5], 
TupleType(Array(fields))); return true; } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold) { + Expr score_threshold, std::string output_format = "onnx") { auto attrs = make_object(); + attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, Attrs(attrs), {}); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f29450dbb604..331553388b48 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -151,7 +151,6 @@ def run_tvm_graph( return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - print(mod["main"]) mod = relay.transform.InferType()(mod) vm_exec = relay.vm.compile(mod, target="llvm", params=params) if serialize: @@ -3438,16 +3437,18 @@ def _test_forward_combined_nms( "nms/CombinedNonMaxSuppression:2", "nms/CombinedNonMaxSuppression:3", ], - mode="vm", ) def test_forward_combined_nms(): """CombinedNonMaxSuppression""" _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 1), 0.7, 0.5, 64, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 1), 0.7, 0.5, 10, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 2), 0.7, 0.5, 32, 64) _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 20), 0.7, 0.5, 64, 10) _test_forward_combined_nms((1, 64, 20, 4), (1, 64, 20), 0.7, 0.5, 64, 64, clip_boxes=True) _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 1), 0.4, 0.6, 100, 100) + _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 10), 0.4, 0.2, 150, 1000) ####################################################################### diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 74b8ec51e1fa..57f07b3f00e5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1611,7 +1611,8 @@ def verify_all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - expected_indices, + expected, + output_format="onnx", ): batch_size = boxes_np.shape[0] num_classes = scores_np.shape[1] @@ -1622,23 +1623,23 @@ def verify_all_class_non_max_suppression( ) nms_out = relay.vision.all_class_non_max_suppression( - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ) - three = relay.const(np.array([3]), dtype="int64") - begin = relay.const(np.array([0, 0]), dtype="int64") - end = relay.op.concatenate([nms_out[1], three], axis=0) - strides = relay.const(np.array([1, 1]), dtype="int64") - out = relay.op.strided_slice(nms_out[0], begin, end, strides) - - mod = tvm.IRModule() - mod["main"] = relay.Function([boxes, scores], out) - - check_result([boxes_np, scores_np], mod, [expected_indices]) + if output_format == "onnx": + three = relay.const(np.array([3]), dtype="int64") + begin = relay.const(np.array([0, 0]), dtype="int64") + end = relay.op.concatenate([nms_out[1], three], axis=0) + strides = relay.const(np.array([1, 1]), dtype="int64") + out = relay.op.strided_slice(nms_out[0], begin, end, strides) + mod = tvm.IRModule() + mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, [expected]) + else: + out = nms_out.tuple_value + mod = tvm.IRModule() + 
mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, expected) boxes = np.array( [ @@ -1668,6 +1669,39 @@ def verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected ) + expected = [ + np.array( + [[[0, 4], [0, 2], [1, 4], [1, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]] + ), + np.array( + [ + [ + 0.9, + 0.6, + 0.9, + 0.8, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ] + ), + np.array([4]), + ] + + verify_all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected, + output_format="tensorflow", + ) + boxes = np.array( [ [