diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h new file mode 100644 index 000000000000..2fd98533b589 --- /dev/null +++ b/include/tvm/relax/attrs/vision.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/relax/attrs/vision.h + * \brief Auxiliary attributes for vision operators. + */ +#ifndef TVM_RELAX_ATTRS_VISION_H_ +#define TVM_RELAX_ATTRS_VISION_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in AllClassNonMaximumSuppression operator */ +struct AllClassNonMaximumSuppressionAttrs + : public AttrsNodeReflAdapter { + ffi::String output_format; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "output_format", &AllClassNonMaximumSuppressionAttrs::output_format, + "Output format, onnx or tensorflow. Returns outputs in a way that can be easily " + "consumed by each frontend."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs", + AllClassNonMaximumSuppressionAttrs, BaseAttrsNode); +}; // struct AllClassNonMaximumSuppressionAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_VISION_H_ diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5470c911d30b..abee4911033e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3386,6 +3386,182 @@ def _impl_v11(cls, bb, inputs, attr, params): return input_sequence[position] +class NonMaxSuppression(OnnxOpConverter): + """Converts an onnx NonMaxSuppression node into an equivalent Relax expression.""" + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + """ + NonMaxSuppression performs non-maximum suppression (NMS) on all classes. 
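+        The node is lowered to ``relax.op.vision.all_class_non_max_suppression`` with
+        ``output_format="onnx"`` and only the ``selected_indices`` element of the result
+        tuple is returned. When the optional ``center_point_box`` attribute is nonzero,
+        boxes given as [x_center, y_center, width, height] are converted to corner
+        format before NMS.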
+ + Inputs: + - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2] + - scores: (N, C) tensor of scores for each box and class + - max_output_boxes_per_class: maximum number of boxes to keep per class + - iou_threshold: IoU threshold for NMS + - score_threshold: score threshold for filtering + + Outputs: + - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx] + """ + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None + iou_threshold = inputs[3] if len(inputs) > 3 else None + score_threshold = inputs[4] if len(inputs) > 4 else None + + center_point_box = attr.get("center_point_box", 0) + + if max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Constant + ): + max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy()) + elif max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Var + ): + var_name = max_output_boxes_per_class.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + max_output_boxes_per_class = int(param_value.numpy().item()) + else: + max_output_boxes_per_class = 100 # Default value + else: + max_output_boxes_per_class = 100 # Default value + + if iou_threshold is not None and isinstance(iou_threshold, relax.Constant): + iou_threshold = float(iou_threshold.data.numpy()) + else: + iou_threshold = 0.5 # Default value + + if score_threshold is not None and isinstance(score_threshold, relax.Constant): + score_threshold = float(score_threshold.data.numpy()) + elif score_threshold is not None and isinstance(score_threshold, relax.Var): + var_name = score_threshold.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + score_threshold = float(param_value.numpy().item()) + else: + score_threshold = 0.0 # Default value + else: + score_threshold = 0.0 # Default value + + if center_point_box != 0: + split_result = relax.op.split(boxes, 4, axis=2) + xc = split_result[0] + yc = split_result[1] + w = split_result[2] + h = split_result[3] + half_w = w / relax.const(2.0, boxes.struct_info.dtype) + half_h = h / relax.const(2.0, boxes.struct_info.dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = relax.op.concat([y1, x1, y2, x2], axis=2) + + nms_out = bb.normalize( + relax.op.vision.all_class_non_max_suppression( + boxes, + scores, + relax.const(max_output_boxes_per_class, dtype="int64"), + relax.const(iou_threshold, dtype="float32"), + relax.const(score_threshold, dtype="float32"), + output_format="onnx", + ) + ) + + selected_indices = bb.emit(relax.TupleGetItem(nms_out, 0)) + + return selected_indices + + +class AllClassNMS(OnnxOpConverter): + """Converts an onnx AllClassNMS node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + """ + AllClassNMS performs non-maximum suppression (NMS) on all classes. 
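+        The conversion mirrors the ``NonMaxSuppression`` converter above, except that the
+        full ``(selected_indices, num_total_detection)`` tuple produced by
+        ``relax.op.vision.all_class_non_max_suppression`` is returned instead of only the
+        indices.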
+ + Inputs: + - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2] + - scores: (N, C) tensor of scores for each box and class + - max_output_boxes_per_class: maximum number of boxes to keep per class + - iou_threshold: IoU threshold for NMS + - score_threshold: score threshold for filtering + + Outputs: + - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx] + """ + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None + iou_threshold = inputs[3] if len(inputs) > 3 else None + score_threshold = inputs[4] if len(inputs) > 4 else None + + center_point_box = attr.get("center_point_box", 0) + + if max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Constant + ): + max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy()) + elif max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Var + ): + var_name = max_output_boxes_per_class.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + max_output_boxes_per_class = int(param_value.numpy().item()) + else: + max_output_boxes_per_class = 100 # Default value + else: + max_output_boxes_per_class = 100 # Default value + + if iou_threshold is not None and isinstance(iou_threshold, relax.Constant): + iou_threshold = float(iou_threshold.data.numpy()) + else: + iou_threshold = 0.5 # Default value + + if score_threshold is not None and isinstance(score_threshold, relax.Constant): + score_threshold = float(score_threshold.data.numpy()) + elif score_threshold is not None and isinstance(score_threshold, relax.Var): + var_name = score_threshold.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + score_threshold = float(param_value.numpy().item()) + else: + score_threshold = 0.0 # Default value + else: + score_threshold = 0.0 # Default value + + if center_point_box != 0: + split_result = relax.op.split(boxes, 4, axis=2) + xc = split_result[0] + yc = split_result[1] + w = split_result[2] + h = split_result[3] + half_w = w / relax.const(2.0, boxes.struct_info.dtype) + half_h = h / relax.const(2.0, boxes.struct_info.dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = relax.op.concat([y1, x1, y2, x2], axis=2) + + nms_out = bb.normalize( + relax.op.vision.all_class_non_max_suppression( + boxes, + scores, + relax.const(max_output_boxes_per_class, dtype="int64"), + relax.const(iou_threshold, dtype="float32"), + relax.const(score_threshold, dtype="float32"), + output_format="onnx", + ) + ) + + return nms_out + + def _get_convert_map(): return { # defs/experimental @@ -3536,7 +3712,8 @@ def _get_convert_map(): # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, # "RoiAlign": RoiAlign, - # "NonMaxSuppression": NonMaxSuppression, + "NonMaxSuppression": NonMaxSuppression, + "AllClassNMS": AllClassNMS, # "GridSample": GridSample, "Upsample": Upsample, # others diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index fd3672368b68..e1635d64e63a 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -154,6 +154,7 @@ tanh, trunc, ) +from .vision import all_class_non_max_suppression def _register_op_make(): diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 4062aae0c7c4..229a789a45ef 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -239,6 +239,11 @@ class AttentionAttrs(Attrs): """Attributes used 
in attention operator""" +@tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs") +class AllClassNonMaximumSuppressionAttrs(Attrs): + """Attributes for vision.all_class_non_max_suppression""" + + @tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py new file mode 100644 index 000000000000..be45458d3647 --- /dev/null +++ b/python/tvm/relax/op/vision/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""VISION operators.""" +from .nms import * diff --git a/python/tvm/relax/op/vision/_ffi_api.py b/python/tvm/relax/op/vision/_ffi_api.py new file mode 100644 index 000000000000..8af761dc5a00 --- /dev/null +++ b/python/tvm/relax/op/vision/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm_ffi + +tvm_ffi.init_ffi_api("relax.op.vision", __name__) diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py new file mode 100644 index 000000000000..3714b00b01e2 --- /dev/null +++ b/python/tvm/relax/op/vision/nms.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Non-maximum suppression operator""" +# from tvm import relax # Unused import +from . import _ffi_api + + +def all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : relax.Expr + 3-D tensor with shape (batch_size, num_boxes, 4) + scores: relax.Expr + 3-D tensor with shape (batch_size, num_classes, num_boxes) + max_output_boxes_per_class : relax.Expr + The maxinum number of output selected boxes per class + iou_threshold : relax.Expr + IoU test threshold + score_threshold : relax.Expr + Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. + + Returns + ------- + out : relax.Expr + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size + `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. + + TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly. + This would eliminate the need for manual trimming and improve memory efficiency. + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. + """ + return _ffi_api.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format + ) diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index b4aba0291fc1..5614d0229646 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -31,3 +31,4 @@ from . import search from . import statistical from . import unary +from . import vision diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py new file mode 100644 index 000000000000..f910f62cec64 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default legalization function for vision network related operators.""" +from tvm import topi, te +from tvm import relax +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): + """Create a proper NMS implementation that follows the correct algorithm""" + scores_shape = list(scores.shape) + if len(scores_shape) == 3: + batch, num_classes, _ = scores_shape + elif len(scores_shape) == 2: + num_classes, _ = scores_shape + batch = 1 + else: + raise ValueError(f"Unexpected scores shape: {scores_shape}") + + if hasattr(max_output_boxes_per_class, "data"): + max_boxes = int(max_output_boxes_per_class.data.numpy()) + else: + max_boxes = 3 # Default value + + expected_detections = batch * num_classes * max_boxes + + selected_indices_full, _ = topi.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + + def slice_to_onnx_shape(data, expected_size): + def compute_element(i, j): + return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0)) + + return te.compute((expected_size, 3), compute_element, name="sliced_indices") + + sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections) + + actual_detections = te.compute( + (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections" + ) + + return [sliced_indices, actual_detections] + + +@register_legalize("relax.vision.all_class_non_max_suppression") +def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: + """Legalize all_class_non_max_suppression with fixed shape output. + + Note: This implementation outputs fixed-size tensors with trailing garbage data. + Only the first `num_total_detection` rows contain valid data. Users should use + the `valid_count` tensor to determine how many rows are actually valid. 
+ + For complete ONNX compatibility, users can post-process the output: + ```python + selected_indices, valid_count = nms_output + actual_count = int(valid_count.numpy()[0]) + valid_indices = selected_indices.numpy()[:actual_count, :] + ``` + """ + boxes = call.args[0] + scores = call.args[1] + max_output_boxes_per_class = call.args[2] + iou_threshold = call.args[3] + score_threshold = call.args[4] + output_format = call.attrs.output_format + + scores_shape = scores.struct_info.shape + if len(scores_shape) == 3: + _, _, num_boxes = scores_shape + elif len(scores_shape) == 2: + _, num_boxes = scores_shape + else: + raise ValueError(f"Unexpected scores shape: {scores_shape}") + + if isinstance(max_output_boxes_per_class, relax.Constant): + max_boxes_val = int(max_output_boxes_per_class.data.numpy()) + else: + max_boxes_val = int(num_boxes) + + # Get NMS result with fixed shape from TOPI + nms_result = block_builder.call_te( + topi.vision.all_class_non_max_suppression, + boxes, + scores, + max_boxes_val, + iou_threshold, + score_threshold, + output_format, + ) + + # TODO: Implement dynamic output trimming for better memory efficiency + # Current approach returns fixed-size output with trailing garbage data + # Future improvements could include: + # 1. Dynamic strided_slice based on num_total_detections + # 2. Custom Relax operator with true dynamic shapes + # 3. VM builtin functions for runtime shape adjustment + # 4. Symbolic shape inference in Relax IR + # + # For now, users should trim manually: + # actual_count = int(num_total_detections.numpy()[0]) + # valid_indices = selected_indices.numpy()[:actual_count, :] + + return nms_result diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d28ff3430aaa..1b69a794e6b4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -186,6 +186,7 @@ wrap_param, zeros, zeros_like, + vision, ) from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo @@ -896,4 +897,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "nn", "ccl", "erf", + "vision", ] diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 9503aea0cd2f..c73e8bf54cf5 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -50,6 +50,7 @@ from . import nn from . import utils from . import image +from . import vision from . import gpu # error reporting diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/cpp/vision/__init__.py index 8acbb3861067..467ce70fbd33 100644 --- a/python/tvm/topi/cpp/vision/__init__.py +++ b/python/tvm/topi/cpp/vision/__init__.py @@ -19,5 +19,6 @@ import tvm_ffi from . import yolo +from ...vision import nms tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision") diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py new file mode 100644 index 000000000000..f12758bb9c0a --- /dev/null +++ b/python/tvm/topi/vision/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Vision operators.""" +from .nms import * diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py new file mode 100644 index 000000000000..f4aae45ef9c5 --- /dev/null +++ b/python/tvm/topi/vision/nms.py @@ -0,0 +1,500 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args +"""Non-maximum suppression operator""" +import tvm +from tvm import te + +from tvm.tir import if_then_else + +from ..sort import argsort +from ..math import cast +from ..transform import reshape, gather +from .. import reduction +from ..scan import cumsum +from .nms_util import ( + binary_search, + collect_selected_indices, + collect_selected_indices_and_scores, + run_all_class_nms, +) + + +def get_valid_counts( + data, score_threshold=0, id_index=0, score_index=1 +): # pylint: disable=unused-argument + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + Parameters + ---------- + data : tvm.te.Tensor + Input data. 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. + score_threshold : optional, float + Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + score_index: optional, int + Index of the scores/confidence of boxes. + Returns + ------- + valid_count : tvm.te.Tensor + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor + Rearranged data tensor. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. 
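+    Note: this implementation does not filter boxes. ``valid_count`` is set to
+    ``num_anchors`` for every batch, the input data is returned unchanged, and
+    ``out_indices`` is the identity mapping; score filtering is left to the
+    downstream NMS kernels.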
+ """ + if isinstance(score_threshold, (float, int)): + score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype) + # id_index_const = tvm.tir.const(id_index, "int32") # Unused + # score_index_const = tvm.tir.const(score_index, "int32") # Unused + return ( + te.compute((data.shape[0],), lambda i: data.shape[1], name="valid_count"), + data, + te.compute((data.shape[0], data.shape[1]), lambda i, j: j, name="out_indices"), + ) + + +def _nms_loop( + ib, + batch_size, + top_k, + iou_threshold, + max_output_size, + valid_count, + on_new_valid_box_func, + on_new_invalidated_box_func, + needs_bbox_check_func, + calc_overlap_func, + out_scores, + num_valid_boxes, + score_threshold=None, +): + def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): + on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_boxes_to_check = nkeep - (j + 1) + + with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k: + k = j + 1 + _k + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + with ib.for_range(0, batch_size, name="i") as i: + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + # Use max_output_size directly without if_then_else + # max_output_size = if_then_else(max_output_size > te.const(0), max_output_size, nkeep) + + with ib.if_scope(tvm.tir.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))): + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 + + # Use for_range to iterate through all boxes, but limit selection count + with ib.for_range(0, nkeep, name="j") as j: + with ib.if_scope( + tvm.tir.all( + out_scores[i, j] > -1.0, # box is still valid + num_valid_boxes_local[0] < max_output_size, # haven't reached max limit + ) + ): + if score_threshold is not None: + with ib.if_scope(out_scores[i, j] > score_threshold[()]): + nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local) + else: + nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local) + + num_valid_boxes[i] = num_valid_boxes_local[0] + + with ib.else_scope(): + num_valid_boxes[i] = 0 + + return ib.get() + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def searchsorted_ir(scores, score_thresh, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + if hasattr(score_threshold, "shape"): + if len(score_threshold.shape) == 0: + score_thresh_scalar = score_thresh[()] + elif len(score_threshold.shape) == 1 and score_threshold.shape[0] > 0: + score_thresh_scalar = score_thresh[0] + else: + score_thresh_scalar = tvm.tir.FloatImm("float32", 0.0) + else: + score_thresh_scalar = score_threshold + binary_search(ib, i, num_boxes, scores, score_thresh_scalar, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + searchsorted_buf = tvm.tir.decl_buffer( + (batch_classes,), "int32", "searchsorted", data_alignment=8 + ) + + if hasattr(score_threshold, "shape"): + score_thresh_buf = tvm.tir.decl_buffer( + score_threshold.shape, score_threshold.dtype, 
"score_thresh_buf", data_alignment=8 + ) + return te.extern( + [(batch_classes,)], + [scores, score_threshold], + lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf, score_thresh_buf], + out_buffers=[searchsorted_buf], + name="searchsorted", + tag="searchsorted", + ) + else: + + def searchsorted_ir_scalar(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + if isinstance(score_threshold, te.Tensor): + if len(score_threshold.shape) == 0: + score_thresh_tir = score_threshold() + elif len(score_threshold.shape) == 1 and score_threshold.shape[0] == 1: + score_thresh_tir = score_threshold[0] + else: + score_thresh_tir = tvm.tir.FloatImm("float32", 0.0) + else: + score_thresh_tir = tvm.tir.FloatImm("float32", float(score_threshold)) + binary_search(ib, i, num_boxes, scores, score_thresh_tir, valid_count) + + return ib.get() + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir_scalar(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + out_buffers=[searchsorted_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _collect_selected_indices_ir( + num_class, selected_indices, num_detections, row_offsets, out, max_output_boxes_per_class=None +): + batch_classes, _ = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + # Initialize output buffer to zero + # Calculate the actual output shape based on max_output_boxes_per_class + if isinstance(max_output_boxes_per_class, int): + max_output_rows = batch_classes * max_output_boxes_per_class + else: + # Fallback to a reasonable default if max_output_boxes_per_class is not an integer + max_output_rows = batch_classes * 10 + with ib.for_range(0, max_output_rows, name="init_i") as init_i: + with ib.for_range(0, 3, name="init_j") as init_j: # 3 columns + out[init_i, init_j] = cast(0, "int64") + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + if isinstance(max_output_boxes_per_class, int): + limit = tvm.tir.min( + num_detections[i], tvm.tir.IntImm("int32", max_output_boxes_per_class) + ) + elif isinstance(max_output_boxes_per_class, te.Tensor): + if len(max_output_boxes_per_class.shape) == 0: + max_boxes_val = max_output_boxes_per_class[()] + else: + max_boxes_val = max_output_boxes_per_class[0] + limit = tvm.tir.min(num_detections[i], max_boxes_val) + else: + limit = num_detections[i] + + with ib.for_range(0, limit, name="j") as j: + out[row_offsets[i] + j, 0] = batch_id + out[row_offsets[i] + j, 1] = class_id + out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64") + + return ib.get() + + +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + 
num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + zero = cast(0, "int64") + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + with ib.if_scope(j < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + +def all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", + output_shape=None, +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + scores: tvm.te.Tensor + 3-D tensor with shape (batch_size, num_classes, num_boxes) + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maxinum number of output selected boxes per class + iou_threshold : float or tvm.te.Tensor, optionaIl + IoU test threshold + score_threshold : float or tvm.te.Tensor, optional + Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. + Returns + ------- + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size + `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. + + .. note:: + **Important**: The output tensor has a fixed size based on `max_output_boxes_per_class`, + but only the first `num_total_detection` rows contain valid data. The remaining rows + may contain garbage values. When comparing with ONNX Runtime or other implementations + that output dynamic shapes, you should only compare the first + `num_total_detection` rows. + Example: + ```python + selected_indices, valid_count = nms_output + actual_count = int(valid_count.numpy()[0]) + valid_indices = selected_indices.numpy()[:actual_count, :] + ``` + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. 
The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. + """ + batch, num_class, num_boxes = scores.shape + scores = reshape(scores, (batch * num_class, num_boxes)) + + sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_scores = gather(scores, 1, sorted_indices) + + if not isinstance(score_threshold, te.Tensor): + score_threshold_tensor = te.compute((), lambda: score_threshold, name="score_threshold") + else: + score_threshold_tensor = score_threshold + + valid_count = _get_valid_box_count(sorted_scores, score_threshold_tensor) + + selected_indices, selected_scores, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, + return_scores=(output_format == "tensorflow"), + score_threshold=score_threshold_tensor, # Passed score_threshold as tensor + ) + + if output_format == "onnx": + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + + def _sum_clamped_total(): + if isinstance(max_output_boxes_per_class, int): + k_expr = tvm.tir.IntImm("int32", int(max_output_boxes_per_class)) + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], k_expr), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + if isinstance(max_output_boxes_per_class, tvm.tir.IntImm): + k_expr = tvm.tir.Cast("int32", max_output_boxes_per_class) + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], k_expr), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + if isinstance(max_output_boxes_per_class, te.Tensor): + if len(max_output_boxes_per_class.shape) == 0: + kb = te.compute( + num_detections.shape, + lambda i: cast(max_output_boxes_per_class, "int32"), + name="k_broadcast", + ) + elif ( + len(max_output_boxes_per_class.shape) == 1 + and max_output_boxes_per_class.shape[0] == 1 + ): + kb = te.compute( + num_detections.shape, + lambda i: cast(max_output_boxes_per_class[0], "int32"), + name="k_broadcast", + ) + else: + return reduction.sum(cast(num_detections, "int64"), axis=0) + + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], kb[i]), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + return reduction.sum(cast(num_detections, "int64"), axis=0) + + num_total_scalar = _sum_clamped_total() + num_total_detections = reshape(num_total_scalar, (1,)) + + if output_shape is not None: + selected_indices = collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + _collect_selected_indices_ir, + max_output_boxes_per_class=max_output_boxes_per_class, + output_shape=output_shape, + ) + else: + # Use num_total_detections to enable dynamic trimming + # Pass image size for intelligent default estimation + input_image_size = None + if hasattr(scores, "shape") and len(scores.shape) >= 3: + # Extract image size from scores shape: (batch, num_classes, num_boxes) + # We can estimate image size from num_boxes (more boxes = larger image) + input_image_size = (scores.shape[2],) # Use num_boxes as proxy for image size + + # TODO: 
Improve image size estimation by: + # 1. Accepting actual image dimensions as parameters + # 2. Using model metadata to infer typical image sizes + # 3. Learning from historical detection patterns + # 4. Providing user-configurable estimation strategies + + selected_indices = collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + _collect_selected_indices_ir, + max_output_boxes_per_class=max_output_boxes_per_class, + num_total_detections=num_total_detections, + input_image_size=input_image_size, + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) + row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) + num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) + + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, + ) + + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py new file mode 100644 index 000000000000..1633c923e17f --- /dev/null +++ b/python/tvm/topi/vision/nms_util.py @@ -0,0 +1,473 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Common utilities used in Non-maximum suppression operators""" +import tvm +from tvm import te + + +def _get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], + ) + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], + ) + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], + ) + return l, t, r, b + + +def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) + + +def binary_search(ib, y, num_boxes, scores, score_threshold, out): + """Binary search for score_threshold on scores sorted in descending order""" + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = num_boxes.astype("int32") + + with ib.while_loop(lo[0] < hi[0]): + mid = (hi[0] + lo[0]) >> 1 + with ib.if_scope(scores[y, mid] > score_threshold): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + out[y] = lo[0] + + +def _estimate_max_detections(batch_class, input_image_size=None): + """Estimate maximum detections based on input image size and number of classes. + + This provides a more intelligent default for production environments. 
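+    The per-class detection budget scales with the number of input pixels and is capped
+    at 50 when there are more than 20 batch*class rows; without an image size it falls
+    back to 100, 50, or 25 detections per class for 1, up to 10, or more than 10 rows
+    respectively. For example, 80 batch*class rows with a 640x640 input give
+    80 * 50 = 4000 estimated rows.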
+ """ + if input_image_size is not None: + # Estimate based on image size: larger images typically have more objects + if len(input_image_size) >= 2: + height, width = input_image_size[-2], input_image_size[-1] + total_pixels = height * width + + # Base estimation per class based on image size + if total_pixels < 300000: # Small images (< 300k pixels) + base_detections_per_class = min(50, max(10, total_pixels // 2000)) + elif total_pixels < 1000000: # Medium images (< 1M pixels) + base_detections_per_class = min(100, max(25, total_pixels // 3000)) + else: # Large images (>= 1M pixels) + base_detections_per_class = min(200, max(50, total_pixels // 4000)) + + # Scale down for many classes (more realistic for multi-class scenarios) + if batch_class > 20: + # For many classes, reduce per-class detections to avoid explosion + detections_per_class = min(base_detections_per_class, 50) + else: + detections_per_class = base_detections_per_class + else: + detections_per_class = 50 # fallback + else: + # Fallback to class-based estimation + if batch_class == 1: + detections_per_class = 100 # Single class detection + elif batch_class <= 10: + detections_per_class = 50 # Small multi-class + else: + detections_per_class = 25 # Large multi-class (COCO-like) + + return batch_class * detections_per_class + + +def collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + ir, + max_output_boxes_per_class=None, + output_shape=None, + num_total_detections=None, + input_image_size=None, +): + """Collect selected indices from the core NMS loop into one linear output + Parameters + ---------- + num_class : int + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + num_detections tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes selected by the core NMS loop, per batch and class + row_offsets tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan + of num_detections + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + Returns + ------- + out : tvm.te.Tensor + The output is indices of size (batch_size * num_class* num_boxes , 3). + Rows of indices are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. 
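+    The number of output rows is fixed at build time: ``output_shape`` is used when given;
+    otherwise it is ``batch_class * max_output_boxes_per_class`` when that value is a
+    Python int or a readable constant tensor; when only ``num_total_detections`` is
+    available the row count comes from ``_estimate_max_detections``; the final fallback
+    is ``batch_class * num_boxes``.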
+ """ + batch_class, num_boxes = selected_indices.shape + + if output_shape is not None: + return te.extern( + [output_shape], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + # TODO: Implement dynamic trimming based on num_total_detections + if num_total_detections is not None: + if isinstance(max_output_boxes_per_class, int): + out_rows = batch_class * max_output_boxes_per_class + else: + # Smart fallback based on input image size and typical production scenarios + out_rows = _estimate_max_detections(batch_class, input_image_size) + + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + if isinstance(max_output_boxes_per_class, int): + out_rows = batch_class * max_output_boxes_per_class + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + if isinstance(max_output_boxes_per_class, te.Tensor): + try: + if len(max_output_boxes_per_class.shape) == 0: + max_boxes_val = int(max_output_boxes_per_class.data.numpy()) + elif ( + len(max_output_boxes_per_class.shape) == 1 + and max_output_boxes_per_class.shape[0] == 1 + ): + max_boxes_val = int(max_output_boxes_per_class.data.numpy()[0]) + else: + max_boxes_val = num_boxes + except (ValueError, IndexError, AttributeError): + max_boxes_val = num_boxes + + out_rows = batch_class * max_boxes_val + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + return te.extern( + [(batch_class * num_boxes, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir +): + """Collect selected indices and scores from the core NMS loop into one linear output + Parameters + ---------- + num_class : int + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores + of selected boxes by the core NMS loop. + num_detections tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), representing + the number of boxes selected by the core NMS loop, per batch and class + row_offsets tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan + of num_detections along axis 1 + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors. 
The first is indices of size + (batch_size, num_class* num_boxes, 2), and the second is scores of size + (batch_size, num_class* num_boxes). + """ + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), + dtype=["int64", "float32"], + name="collect_indices_and_scores", + tag="collect_indices_and_scores", + ) + + +def _all_class_nms_ir( + boxes, + sorted_scores, + sorted_indices, + valid_count, + batch_class, + num_class, + num_anchors, + iou_threshold, + max_output_size_per_class, + box_indices, + selected_scores, + num_valid_boxes, + nms_loop, + score_threshold=None, +): + ib = tvm.tir.ir_builder.create() + boxes = ib.buffer_ptr(boxes) + sorted_scores = ib.buffer_ptr(sorted_scores) + sorted_indices = ib.buffer_ptr(sorted_indices) + valid_count = ib.buffer_ptr(valid_count) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + elif isinstance(iou_threshold, te.Tensor): + if len(iou_threshold.shape) == 0: + iou_threshold = iou_threshold() + elif len(iou_threshold.shape) == 1 and iou_threshold.shape[0] == 1: + iou_threshold = iou_threshold[0] + else: + iou_threshold = tvm.tir.FloatImm("float32", 0.5) + + if isinstance(max_output_size_per_class, int): + max_output_size_per_class = tvm.tir.const(max_output_size_per_class) + elif isinstance(max_output_size_per_class, te.Tensor): + if len(max_output_size_per_class.shape) == 0: + max_output_size_per_class = max_output_size_per_class() + elif len(max_output_size_per_class.shape) == 1 and max_output_size_per_class.shape[0] == 1: + # Use tensor indexing to get the first element + max_output_size_per_class = max_output_size_per_class[0] + else: + max_output_size_per_class = tvm.tir.const(1000) + + def calc_overlap(i, j, k): + offset_j = sorted_indices[i, j] * 4 + offset_k = sorted_indices[i, k] * 4 + batch_id = i // num_class + base_bbox_idx = batch_id * num_anchors * 4 + return calculate_overlap( + boxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, + ) + + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + with ib.if_scope(tid + 0 == 0): + box_indices[i, num_current_valid_box] = sorted_indices[i, j] + + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + + def on_new_invalidated_box(*_): + pass + + def needs_bbox_check(*_): + return tvm.tir.const(True) + + return nms_loop( + ib, + batch_class, + tvm.tir.IntImm("int32", -1), # top_k + iou_threshold, + max_output_size_per_class, + valid_count, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + sorted_scores, + num_valid_boxes, + score_threshold, + ) + + +def run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_size_per_class, + iou_threshold, + nms_loop, + return_scores=False, + score_threshold=None, +): + """The core all class NMS routine + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + sorted_scores: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes) + One of 
the outputs from argsort + sorted_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes) + The other output from argsort + valid_count: tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes whose score is above score_threshold, per batch and class + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maxinum number of output selected boxes per class + iou_threshold : float or tvm.te.Tensor, optionaIl + IoU test threshold + nms_loop : function + A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + return_scores : bool, optional + Whether or not to return selected scores, needed by the tensorflow output format. + Returns + ------- + out : a list of tvm.te.Tensor + The output is three tensors, the first and second are indices and scores of size + (batch_size * num_class, num_boxes), and the third is a tensor + num_selected_boxes of shape (batch_size * num_class,) representing the total number of + selected boxes per batch and class. If return_scores is False, the second output is + None. + """ + batch, num_boxes, _ = boxes.shape + batch_class = sorted_scores.shape[0] + num_class = batch_class // batch + + if return_scores is False: + all_class_num0_buf = tvm.tir.decl_buffer( + (batch_class, num_boxes), "int32", "all_class_nms0", data_alignment=8 + ) + all_class_num1_buf = tvm.tir.decl_buffer( + (batch_class,), "int32", "all_class_nms1", data_alignment=8 + ) + extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count] + if score_threshold is not None: + extern_inputs.append(score_threshold) + + selected_indices, num_detections = te.extern( + [(batch_class, num_boxes), (batch_class,)], + extern_inputs, + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ins[4] if score_threshold is not None else None, # score_threshold + ), + out_buffers=[all_class_num0_buf, all_class_num1_buf], + dtype=["int32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) + return selected_indices, None, num_detections + + extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count] + if score_threshold is not None: + extern_inputs.append(score_threshold) + + return te.extern( + [(batch_class, num_boxes), (batch_class, num_boxes), (batch_class,)], + extern_inputs, + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + outs[1], # selected scores + outs[2], # num_selected_boxes + nms_loop, + ins[4] if score_threshold is not None else None, # score_threshold + ), + dtype=["int32", "float32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index bb4098ae82d2..f09dcb7f8230 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -51,6 +51,10 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { .def_ro("shape", &RXPlaceholderOpNode::shape) .def_ro("dtype", &RXPlaceholderOpNode::dtype); } + + // FFI system configuration for structural equality and hashing + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TEPlaceholderOp", RXPlaceholderOpNode, te::PlaceholderOpNode); }; diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc new file mode 100644 index 000000000000..2a1ad8f40aa4 --- /dev/null +++ b/src/relax/op/vision/nms.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "nms.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { AllClassNonMaximumSuppressionAttrs::RegisterReflection(); } + +/* relax.vision.all_class_non_max_suppression */ + +Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class, + Expr iou_threshold, Expr score_threshold, + ffi::String output_format) { + auto attrs = tvm::ffi::make_object(); + attrs->output_format = output_format; + + static const Op& op = Op::Get("relax.vision.all_class_non_max_suppression"); + return Call(op, + {std::move(boxes), std::move(scores), std::move(max_output_boxes_per_class), + std::move(iou_threshold), std::move(score_threshold)}, + Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.all_class_non_max_suppression", + all_class_non_max_suppression); +} + +StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) { + tvm::ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto boxes_sinfo = input_sinfo[0]; + const auto scores_sinfo = input_sinfo[1]; + ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim"; + ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim"; + ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D."; + ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should be 3-D."; + + const auto batch = boxes_sinfo->shape.as()->values[0]; + const auto num_classes = scores_sinfo->shape.as()->values[1]; + const auto num_boxes = boxes_sinfo->shape.as()->values[1]; + + auto vdev = input_sinfo[0]->vdevice; + const auto* attrs = call->attrs.as(); + if (attrs->output_format == "onnx") { + auto vdev = input_sinfo[0]->vdevice; + auto num_total_boxes = batch * num_classes * num_boxes; + tvm::ffi::Array oshape_values = {num_total_boxes, 3}; + ShapeExpr oshape(oshape_values); + tvm::ffi::Array counts_values = {1}; + ShapeExpr counts_shape(counts_values); + tvm::ffi::Array fields = {TensorStructInfo(oshape, DataType::Int(64), vdev), + TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; + return TupleStructInfo(fields); + } + + auto num_total_boxes_per_batch = num_classes * num_boxes; + tvm::ffi::Array indices_values = {batch, num_total_boxes_per_batch, 2}; + 
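+  // "tensorflow" output format: per-batch indices (batch, num_classes * num_boxes, 2),
+  // scores (batch, num_classes * num_boxes) and num_total_detection (batch,).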
ShapeExpr indices_shape(indices_values); + tvm::ffi::Array<PrimExpr> scores_values = {batch, num_total_boxes_per_batch}; + ShapeExpr scores_shape(scores_values); + tvm::ffi::Array<PrimExpr> counts_values = {batch}; + ShapeExpr counts_shape(counts_values); + tvm::ffi::Array<StructInfo> fields = {TensorStructInfo(indices_shape, DataType::Int(64), vdev), + TensorStructInfo(scores_shape, DataType::Float(32), vdev), + TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; + return TupleStructInfo(fields); +} + +TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") + .set_attrs_type<AllClassNonMaximumSuppressionAttrs>() + .set_num_inputs(5) + .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") + .add_argument("scores", "Tensor", + "Scores for each box and class in the format [batch, num_classes, num_boxes].") + .add_argument("max_output_boxes_per_class", "Tensor", + "The maximum number of output boxes per class.") + .add_argument("iou_threshold", "Tensor", "The IoU threshold for the box overlap test.") + .add_argument("score_threshold", "Tensor", + "The score threshold to filter out low score boxes early.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllClassNMS) + .set_attr<Bool>("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h new file mode 100644 index 000000000000..c86bf98c94d5 --- /dev/null +++ b/src/relax/op/vision/nms.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file nms.h + * \brief Functions for constructing Relax non-maximum suppression operator calls. + */ + +#ifndef TVM_RELAX_OP_VISION_NMS_H_ +#define TVM_RELAX_OP_VISION_NMS_H_ + +#include +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Compute All Class NonMaximumSuppression. */ +Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class, + Expr iou_threshold, Expr score_threshold, + ffi::String output_format); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_VISION_NMS_H_ diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 24c16ab2683e..fa84ab3863fb 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -650,7 +650,10 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf // reads/writes filled in. BufferSubstituter substituter(var_map, input_buffer_map); - Stmt body = substituter(extern_op->body); + Stmt substituted_body = substituter(extern_op->body); + + ProducerToBufferTransformer transformer(info->tensor2buffers); + Stmt body = transformer(substituted_body); // Step 4. Generate opaque block as body.
return BlockRealize(/*iter_values=*/{}, diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 625cdebf7f61..4232f59233a6 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -3130,6 +3130,7 @@ def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", " gv: R.Tensor((A, B, A // B), dtype="float32") = x R.output(gv) return gv + # fmt: on tvm.ir.assert_structural_equal(tvm_model, Expected) @@ -3169,5 +3170,430 @@ def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", " tvm.ir.assert_structural_equal(tvm_model, Expected) +def test_nms(): + """Test NonMaxSuppression operator conversion using our AllClassNMS implementation.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + boxes_shape = [1, 5, 4] # batch_size, num_boxes, 4 + scores_shape = [1, 2, 5] # batch_size, num_classes, num_boxes + + graph = helper.make_graph( + [nms_node], + "nms_test", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [0, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test") + model.opset_import[0].version = 11 + + # Use deterministic random inputs for consistent testing + bg = np.random.MT19937(0) + rg = np.random.Generator(bg) + boxes = rg.standard_normal(size=boxes_shape).astype(np.float32) + scores = rg.standard_normal(size=scores_shape).astype(np.float32) + inputs = {"boxes": boxes, "scores": scores} + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + +def test_nms_algorithm_correctness(): + """Test NMS algorithm correctness with fixed data to verify suppression logic.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + 
center_point_box=0, + ) + + # Create fixed test data with known expected results + # Boxes: [x1, y1, x2, y2] format + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - should be selected + [ + 0.5, + 0.5, + 1.5, + 1.5, + ], # Box 1: [0.5,0.5,1.5,1.5] - overlaps with box 0, should be suppressed + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 2: [2,2,3,3] - no overlap, should be selected + dtype=np.float32, + ) + + # Scores: higher score = better + scores_data = np.array( + [ + [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]] # Class 0: [0.9, 0.8, 0.7] - box 0 has highest score + ], # Class 1: [0.6, 0.5, 0.4] - box 0 has highest score + dtype=np.float32, + ) + + boxes_shape = [1, 3, 4] # batch_size, num_boxes, 4 + scores_shape = [1, 2, 3] # batch_size, num_classes, num_boxes + + graph = helper.make_graph( + [nms_node], + "nms_test_correctness", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor( + "max_output_boxes_per_class", TensorProto.INT64, [1], [2] + ), # Only 2 boxes per class + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), # IoU threshold 0.5 + helper.make_tensor( + "score_threshold", TensorProto.FLOAT, [1], [0.1] + ), # Score threshold 0.1 + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [4, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_correctness") + + # Use fixed inputs instead of random + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + check_correctness(model, inputs=inputs, opset=11) + + +def test_nms_iou_suppression(): + """Test that NMS correctly suppresses overlapping boxes based on IoU threshold.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create overlapping boxes where box 0 has higher score and should be kept + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - highest score + [ + 0.1, + 0.1, + 1.1, + 1.1, + ], # Box 1: [0.1,0.1,1.1,1.1] - high IoU with box 0, should be suppressed + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 2: [2,2,3,3] - no overlap, should be kept + dtype=np.float32, + ) + + # Box 0 has highest score, Box 1 should be suppressed due to IoU with box 0 + scores_data = np.array([[[0.9, 0.8, 0.7]]], dtype=np.float32) + + boxes_shape = [1, 3, 4] + scores_shape = [1, 1, 3] + + graph = helper.make_graph( + [nms_node], + "nms_test_iou_suppression", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [2]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), # IoU threshold 0.5 + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_iou_suppression") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + 
tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + +def test_nms_max_boxes_limit(): + """Test that NMS correctly limits the number of boxes per class.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create data with 4 boxes, but limit to 2 per class + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0 + [2.0, 0.0, 3.0, 1.0], # Box 1 + [0.0, 2.0, 1.0, 3.0], # Box 2 + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 3 + dtype=np.float32, + ) + + # All boxes have different scores + scores_data = np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32) + + boxes_shape = [1, 4, 4] + scores_shape = [1, 1, 4] + + graph = helper.make_graph( + [nms_node], + "nms_test_max_boxes_limit", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor( + "max_output_boxes_per_class", TensorProto.INT64, [1], [2] + ), # Limit to 2 boxes + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]), # Low IoU threshold + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_max_boxes_limit") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() 
+ else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + +def test_nms_score_threshold(): + """Test that NMS correctly filters boxes based on score threshold. + + Note: This test uses a low score threshold (0.05) to ensure both TVM and ONNX Runtime + output the same fixed shape [3,3], allowing use of the standard check_correctness function. + """ + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create data with varying scores - ensure we get exactly 3 boxes after NMS + boxes_data = np.array( + [ + [[0.0, 0.0, 1.0, 1.0], [2.0, 0.0, 3.0, 1.0], [0.0, 2.0, 1.0, 3.0]] # Box 0 # Box 1 + ], # Box 2 + dtype=np.float32, + ) + + # Scores: 0.9, 0.3, 0.1 - adjust score threshold to get exactly 3 boxes + scores_data = np.array([[[0.9, 0.3, 0.1]]], dtype=np.float32) + + boxes_shape = [1, 3, 4] + scores_shape = [1, 1, 3] + + graph = helper.make_graph( + [nms_node], + "nms_test_score_threshold", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]), + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.05]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [3, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_score_threshold") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py new file mode 100644 index 000000000000..97145a53ff3b --- /dev/null +++ b/tests/python/relax/test_op_vision.py @@ -0,0 +1,90 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op, VDevice +from tvm.script import relax as R + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_all_class_non_max_suppression_infer_struct_info(): + bb = relax.BlockBuilder() + batch_size, num_classes, num_boxes = 10, 8, 5 + boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) + scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32")) + max_output_boxes_per_class = relax.const(10, "int64") + iou_threshold = relax.const(0.5, "float32") + score_threshold = relax.const(0.1, "float32") + + _check_inference( + bb, + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + +def test_all_class_non_max_suppression_wrong_input_number(): + bb = relax.BlockBuilder() + boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32")) + scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32")) + + with pytest.raises(TVMError): + relax.op.vision.all_class_non_max_suppression(boxes, scores) + + +def test_all_class_non_max_suppression_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + batch_size = tir.Var("batch_size", "int64") + num_classes = tir.Var("num_classes", "int64") + num_boxes = tir.Var("num_boxes", "int64") + boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) + scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32")) + max_output_boxes_per_class = relax.const(10, "int64") + iou_threshold = relax.const(0.5, "float32") + score_threshold = relax.const(0.1, "float32") + + _check_inference( + bb, + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py new file mode 100644 index 000000000000..66e0adac3d22 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_all_class_non_max_suppression(): + @R.function + def foo( + boxes: R.Tensor((10, 5, 4), "float32"), + scores: R.Tensor((10, 8, 5), "float32"), + max_output_boxes_per_class: R.Tensor((), "int64"), + iou_threshold: R.Tensor((), "float32"), + score_threshold: R.Tensor((), "float32"), + ) -> R.Tuple(R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")): + gv: R.Tuple( + R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64") + ) = R.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + "onnx", + ) + return gv + + boxes = relax.Var("boxes", R.Tensor((10, 5, 4), "float32")) + scores = relax.Var("scores", R.Tensor((10, 8, 5), "float32")) + max_output_boxes_per_class = relax.Var("max_output_boxes_per_class", R.Tensor((), "int64")) + iou_threshold = relax.Var("iou_threshold", R.Tensor((), "float32")) + score_threshold = relax.Var("score_threshold", R.Tensor((), "float32")) + + bb = relax.BlockBuilder() + with bb.function( + "foo", [boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold] + ): + gv = bb.emit( + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()
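
Reviewer note (not part of the patch): the tests above only compare the rows that both runtimes happen to emit, so the semantics of the "onnx" output format are easy to miss. The sketch below is a minimal NumPy reference of what InferStructInfoAllClassNMS and the frontend converter assume: selected_indices is an (M, 3) int64 tensor of [batch_idx, class_idx, box_idx] rows produced by greedy per-class suppression, plus a one-element count tensor. Function names here are illustrative only, and boundary behaviour (strict vs. non-strict score comparison, tie ordering) may not match ONNX Runtime exactly.

import numpy as np


def _iou(box_a, box_b):
    """IoU of two corner-format boxes, e.g. [x1, y1, x2, y2]."""
    lo = np.maximum(box_a[:2], box_b[:2])
    hi = np.minimum(box_a[2:], box_b[2:])
    inter = np.prod(np.clip(hi - lo, 0.0, None))
    area_a = np.prod(box_a[2:] - box_a[:2])
    area_b = np.prod(box_b[2:] - box_b[:2])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def reference_all_class_nms(boxes, scores, max_out_per_class, iou_thresh, score_thresh):
    """Greedy per-class NMS in the 'onnx' layout: ((M, 3) indices, (1,) count)."""
    batch, num_classes, _ = scores.shape
    selected = []
    for b in range(batch):
        for c in range(num_classes):
            kept = []
            for i in np.argsort(-scores[b, c]):  # visit boxes by descending score
                if len(kept) >= max_out_per_class:
                    break
                if scores[b, c, i] < score_thresh:
                    break  # scores are sorted, so the rest fall below the threshold too
                    # (exact boundary handling, < vs <=, may differ between implementations)
                if all(_iou(boxes[b, i], boxes[b, j]) <= iou_thresh for j in kept):
                    kept.append(i)
                    selected.append([b, c, i])
    indices = np.array(selected, dtype=np.int64).reshape(-1, 3)
    return indices, np.array([indices.shape[0]], dtype=np.int64)

For example, with the boxes and scores from test_nms_iou_suppression above and (max_out_per_class=2, iou_thresh=0.5, score_thresh=0.1), this sketch keeps boxes 0 and 2 and suppresses box 1 (IoU with box 0 is roughly 0.68), matching the expectation spelled out in that test's comments.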