diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 45d01ff44fe79..3a61f18eb36e2 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -117,11 +117,14 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + Optional max_total_size; std::string output_format; - TVM_ATTR_FIELD(output_format).set_default("onnx").describe( - "Output format. onnx or tensorflow"); + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, - "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} + "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(max_total_size).set_default(NullValue()).describe("TODO"); + TVM_ATTR_FIELD(output_format).set_default("onnx").describe("Output format. onnx or tensorflow"); + } }; /*! \brief Attributes used in roi_align operators */ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c1087d2dafdf3..ca56cc2182820 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -798,6 +798,7 @@ def _impl(inputs, attr, params, mod): # Get parameter values boxes = inputs[0] scores = inputs[1] + try: max_output_boxes_per_class = int( np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0] @@ -809,6 +810,7 @@ def _impl(inputs, attr, params, mod): ) except Exception: max_output_boxes_per_class = inputs[2] + max_total_size = inputs[3] iou_threshold = np.atleast_1d(inputs[4].data.numpy())[0] score_threshold = np.atleast_1d(inputs[5].data.numpy())[0] @@ -832,6 +834,7 @@ def _impl(inputs, attr, params, mod): max_output_boxes_per_class, iou_threshold, score_threshold, + max_total_size, output_format="tensorflow", ) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 0d6c3ef58cdf2..451d01a4fc053 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1095,7 +1095,17 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + max_total_size = attrs.max_total_size + output_format = attrs.output_format + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold, + max_total_size, + output_format, + ) return _compute_nms diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 703edb040bb74..785579cd7973f 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -152,7 +152,13 @@ def non_max_suppression( def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0, output_format="onnx" + boxes, + scores, + max_output_boxes_per_class=-1, + iou_threshold=-1.0, + score_threshold=-1.0, + max_total_size=None, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -195,6 +201,12 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format, ) return expr.TupleWrapper(out, 2) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 3e6c8cc932377..2bb50ef8ebd98 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1066,7 +1066,7 @@ def collect_selected_indices_tf(selected_indices, selected_scores, num_detection def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="onnx" + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_total_size, output_format="onnx" ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -1127,7 +1127,7 @@ def all_class_non_max_suppression( ) return [selected_indices, num_total_detections] - max_detection_per_batch = 100 + max_detection_per_batch = max_total_size selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 53cd71745d5b4..9524d268a9266 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -157,19 +157,34 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs num_total_boxes = batch * num_classes * num_boxes; } - // assign output type + const auto* param = attrs.as(); + CHECK(param); + std::vector fields; - std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(64))); + if (param->output_format == "onnx") { + std::vector oshape{num_total_boxes, 3}; + std::vector countshape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(countshape, DataType::Int(64))); + } else { + ICHECK(param->max_total_size) << "max_total_size required for tf mode"; + Integer max_total_size = param->max_total_size.value(); + std::vector oshape{batch, max_total_size, 2}; + std::vector countshape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(countshape, DataType::Int(64))); + } + reporter->Assign(types[5], TupleType(Array(fields))); return true; } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold) { + Expr score_threshold, Optional max_total_size = NullValue(), + std::string output_format = "onnx") { auto attrs = make_object(); + attrs->max_total_size = std::move(max_total_size); + attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, Attrs(attrs), {});