Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 9dcd0f0 commit 47a05c4
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,13 @@ def collect_selected_indices_tf(selected_indices, selected_scores, num_detection


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_total_size, output_format="onnx"
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand Down Expand Up @@ -1145,7 +1151,7 @@ def all_class_non_max_suppression(
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
True,
return_scores=True,
)

# tf mode, return (batch_size, max_total_size, 2)
Expand All @@ -1156,16 +1162,9 @@ def all_class_non_max_suppression(
selected_indices, selected_scores = collect_selected_indices_tf(
selected_indices, selected_scores, num_detections_per_batch, row_offsets
)

# TODO
# max_total_detections = reduction.max(num_total_detections)
# selected_scores = strided_slice(
# selected_scores, begin=[0, 0], end=expand_dims(max_total_detections, 0)
# )
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)

num_detections = minimum(num_total_detections, max_detection_per_batch)

return [final_indices, num_detections]

0 comments on commit 47a05c4

Please sign in to comment.