diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 005b900d5d44..3a61f18eb36e 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -117,8 +117,14 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + Optional max_total_size; + std::string output_format; + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, - "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} + "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(max_total_size).set_default(NullValue()).describe("TODO"); + TVM_ATTR_FIELD(output_format).set_default("onnx").describe("Output format. onnx or tensorflow"); + } }; /*! \brief Attributes used in roi_align operators */ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4af73702ad9c..74c2233a517a 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -794,22 +794,76 @@ def _impl(inputs, attr, params, mod): def _combined_nms(): + def all_class_impl( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + clip_boxes, + mod, + ): + indices, num_detections = _op.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format="tensorflow", + ) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) + nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32") + nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1) + + indices_shape = _infer_shape(indices, mod) + indices_dims = len(indices_shape) + indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) + nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1) + + if clip_boxes: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + + # Fill in invalid entries with 0 + box_range = _op.arange( + _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" + ) + batch_size = indices_shape[0] + + if isinstance(batch_size, tvm.tir.Any): + box_range_2d = _op.tile(box_range, _op.concatenate([batch_size, 1])) + else: + box_range_2d = _op.tile(box_range, _op.const([batch_size, 1])) + + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) + nmsed_scores = nmsed_scores * valid_mask + nmsed_classes = nmsed_classes * valid_mask + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 + ) + def _impl(inputs, attr, params, mod): # Get parameter values boxes = inputs[0] scores = inputs[1] try: - max_output_size = int(np.atleast_1d(inputs[2].data.numpy().astype("int64"))[0]) + max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0]) except Exception: try: max_output_size = ( - _infer_value(inputs[2], params, mod).numpy().astype("int64").tolist()[0] + _infer_value(inputs[2], params, mod).asnumpy().astype("int64").tolist()[0] ) except Exception: max_output_size = inputs[2] max_total_size = inputs[3] - iou_threshold = np.atleast_1d(inputs[4].data.numpy())[0] - score_threshold = np.atleast_1d(inputs[5].data.numpy())[0] + iou_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] + score_threshold = np.atleast_1d(inputs[5].data.asnumpy())[0] if attr["pad_per_class"]: raise tvm.error.OpAttributeUnImplemented( "pad_per_class for CombinedNonMaxSuppression is not supported" @@ -821,9 +875,20 @@ def _impl(inputs, attr, params, mod): q = boxes_shape[2] num_classes = scores_shape[2] - if q != num_classes: - # When q is 1, it means same box coords are used for all classes. - boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4)) + if q == 1: + boxes = _op.squeeze(boxes, axis=[2]) + scores_trans = _op.transpose(scores, [0, 2, 1]) + return all_class_impl( + boxes, + scores_trans, + max_output_size, + iou_threshold, + score_threshold, + max_total_size.data.numpy().item(), + attr["clip_boxes"], + mod, + ) + boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4]) scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 0d6c3ef58cdf..451d01a4fc05 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1095,7 +1095,17 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + max_total_size = attrs.max_total_size + output_format = attrs.output_format + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold, + max_total_size, + output_format, + ) return _compute_nms diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 3f829e0b1cc7..785579cd7973 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -152,7 +152,13 @@ def non_max_suppression( def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 + boxes, + scores, + max_output_boxes_per_class=-1, + iou_threshold=-1.0, + score_threshold=-1.0, + max_total_size=None, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -185,6 +191,7 @@ def all_class_non_max_suppression( in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + TODO(trvmorr): explain tf mode """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") @@ -194,6 +201,12 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format, ) return expr.TupleWrapper(out, 2) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 9a3b86d72b18..bb079222de78 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -23,11 +23,13 @@ from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust from tvm.ir import register_intrin_lowering from tvm.tir import if_then_else -from .sort import argsort, argsort_thrust +from .sort import argsort, argsort_thrust, topk from .scan import exclusive_scan from ..utils import ceil_div from ..math import cast -from ..transform import reshape +from .. import reduction +from ..broadcast import minimum +from ..transform import reshape, strided_slice, gather_nd, expand_dims, squeeze from ..vision.nms_util import ( calculate_overlap, binary_search, @@ -988,8 +990,97 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_tf_ir( + num_class, + selected_indices, + selected_scores, + num_detections, + row_offsets, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_size * num_class + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + zero = cast(0, "int64") + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + offset = row_offsets[batch_id, class_id] + idx + + with ib.if_scope(idx < num_detections[batch_id, class_id]): + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") + collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = -1.0 + + return ib.get() + + +def collect_selected_indices_tf(selected_indices, selected_scores, num_detections, row_offsets): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + selected_indices_buf = tvm.tir.decl_buffer( + selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 + ) + selected_scores_buf = tvm.tir.decl_buffer( + selected_scores.shape, selected_scores.dtype, "selected_scores_buf", data_alignment=8 + ) + num_detections_buf = tvm.tir.decl_buffer( + num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 + ) + row_offsets_buf = tvm.tir.decl_buffer( + row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 + ) + + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets], + lambda ins, outs: _collect_selected_indices_tf_ir( + num_class, ins[0], ins[1], ins[2], ins[3], outs[0], outs[1] + ), + dtype=["int64", "float32"], + in_buffers=[selected_indices_buf, selected_scores_buf, num_detections_buf, row_offsets_buf], + name="collect_indices", + tag="collect_indices", + ) + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -1012,6 +1103,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : str + Returns ------- out : [tvm.te.Tensor, tvm.te.Tensor] @@ -1029,7 +1122,28 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + if output_format == "onnx": + selected_indices, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, + ) + + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + max_detection_per_batch = max_total_size + + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -1037,14 +1151,20 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=True, ) + # tf mode, return (batch_size, max_total_size, 2) + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" + num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) - - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices, selected_scores = collect_selected_indices_tf( + selected_indices, selected_scores, num_detections_per_batch, row_offsets ) + topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0] + topk_indices = expand_dims(topk_indices, axis=0) + final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1) + num_detections = minimum(num_total_detections, max_detection_per_batch) - return [selected_indices, num_total_detections] + return [final_indices, num_detections] diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 88983ab89f76..5208aeccd413 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -39,7 +39,9 @@ def traverse(op): traverse(tensor.op) scheduled_ops.append(op) - traverse(outs[0].op) + for o in outs: + traverse(o.op) + return s diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index df30ff775f60..8f7063e48715 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -483,7 +483,7 @@ def gather(data, axis, indices): return cpp.gather(data, axis, indices) -def gather_nd(a, indices): +def gather_nd(a, indices, batch_dims=0): """Gather elements from a n-dimension array.. Parameters @@ -498,7 +498,7 @@ def gather_nd(a, indices): ------- ret : tvm.te.Tensor """ - return cpp.gather_nd(a, indices) + return cpp.gather_nd(a, indices, batch_dims) def matmul(a, b, transp_a=False, transp_b=False): diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 744c5ef7feda..455111086bec 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -728,7 +728,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -750,6 +750,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + + output_format : TODO Returns ------- @@ -783,8 +785,9 @@ def all_class_non_max_suppression( num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir - ) + if output_format == "onnx": + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) return [selected_indices, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 1147b1687783..1c2511d42cd9 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -139,6 +139,7 @@ def _all_class_nms_ir( iou_threshold, max_output_size_per_class, box_indices, + selected_scores, num_valid_boxes, nms_loop, ): @@ -150,6 +151,9 @@ def _all_class_nms_ir( box_indices = ib.buffer_ptr(box_indices) num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -171,6 +175,9 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + def on_new_invalidated_box(*_): pass @@ -201,6 +208,7 @@ def run_all_class_nms( max_output_size_per_class, iou_threshold, nms_loop, + return_scores=False ): """The core all class NMS routine @@ -253,8 +261,38 @@ def run_all_class_nms( valid_count.shape, "int32", "valid_count_buf", data_alignment=4 ) + if return_scores is False: + return te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ), + dtype=["int32", "int32"], + in_buffers=[ + boxes_buf, + sorted_scores_buf, + sorted_indices_buf, + valid_count_buf, + ], + name="all_class_nms", + tag="all_class_nms", + ) + return te.extern( - [(batch_class, num_boxes), (1, batch_class)], + [(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( ins[0], # boxes @@ -267,10 +305,11 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - outs[1], # num_selected_boxes + outs[1], # selected scores + outs[2], # num_selected_boxes nms_loop, ), - dtype=["int32", "int32"], + dtype=["int32", "float32", "int32"], in_buffers=[ boxes_buf, sorted_scores_buf, diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 53cd71745d5b..1e63ccd04721 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -157,19 +157,34 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs num_total_boxes = batch * num_classes * num_boxes; } - // assign output type + const auto* param = attrs.as(); + CHECK(param); + std::vector fields; - std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(64))); + if (param->output_format == "onnx") { + std::vector oshape{num_total_boxes, 3}; + std::vector countshape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(countshape, DataType::Int(64))); + } else { + ICHECK(param->max_total_size) << "max_total_size required for tf mode"; + Integer max_total_size = param->max_total_size.value(); + std::vector oshape{batch, max_total_size, 2}; + std::vector countshape{batch}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(countshape, DataType::Int(64))); + } + reporter->Assign(types[5], TupleType(Array(fields))); return true; } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold) { + Expr score_threshold, Optional max_total_size = NullValue(), + std::string output_format = "onnx") { auto attrs = make_object(); + attrs->max_total_size = std::move(max_total_size); + attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, Attrs(attrs), {}); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0bce3bbc7f53..44f74151075c 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -131,7 +131,7 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = gather_nd(args[0], args[1]); + *rv = gather_nd(args[0], args[1], args[2]); }); TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f29450dbb604..ff000a0aa9b7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -151,7 +151,6 @@ def run_tvm_graph( return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - print(mod["main"]) mod = relay.transform.InferType()(mod) vm_exec = relay.vm.compile(mod, target="llvm", params=params) if serialize: