Skip to content

Commit

Permalink
improve sort on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
1 parent 480f6b7 commit 0c659bf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from ..sort import sort, argsort, topk
from ..math import cast
from ..transform import reshape, arange, expand_dims
from ..transform import reshape, arange, expand_dims, gather
from .. import reduction
from ..scan import cumsum
from .nms_util import (
Expand Down Expand Up @@ -824,8 +824,9 @@ def all_class_non_max_suppression(
batch, num_class, num_boxes = scores.shape
scores = reshape(scores, (batch * num_class, num_boxes))

sorted_scores = sort(scores, axis=1, is_ascend=False)
sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
sorted_scores = gather(scores, 1, sorted_indices)

valid_count = _get_valid_box_count(sorted_scores, score_threshold)

if output_format == "onnx":
Expand Down

0 comments on commit 0c659bf

Please sign in to comment.