Skip to content

Commit

Permalink
do minimum in topi
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 52c5e8a commit da75b2a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,6 @@ def all_class_impl(
max_total_size,
output_format="tensorflow",
)
num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.ir import register_intrin_lowering
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, topk
from .sort import argsort, argsort_thrust
from ..broadcast import minimum
from .scan import exclusive_scan
from ..utils import ceil_div
from ..math import cast
Expand Down Expand Up @@ -1133,4 +1134,6 @@ def all_class_non_max_suppression(
_collect_selected_indices_and_scores_ir,
)

num_total_detections = minimum(num_total_detections, max_total_size)

return [selected_indices, selected_scores, num_total_detections]
3 changes: 3 additions & 0 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..sort import argsort
from ..math import cast
from ..transform import reshape, gather
from ..broadcast import minimum
from .. import reduction
from ..scan import cumsum
from .nms_util import (
Expand Down Expand Up @@ -861,4 +862,6 @@ def all_class_non_max_suppression(
_collect_selected_indices_and_scores_ir,
)

num_total_detections = minimum(num_total_detections, max_total_size)

return [selected_indices, selected_scores, num_total_detections]

0 comments on commit da75b2a

Please sign in to comment.