Skip to content

Commit

Permalink
make it more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 6baee99 commit b02faae
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ def all_class_non_max_suppression(
row_offsets, num_total_detections = exclusive_scan(
num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1
)

selected_indices, selected_scores = collect_selected_indices_and_scores(
selected_indices,
selected_scores,
Expand All @@ -1132,7 +1133,9 @@ def all_class_non_max_suppression(
num_total_detections,
_collect_selected_indices_and_scores_ir,
)

topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")[0]

return post_process_max_detections(
selected_indices, topk_indices, num_total_detections, max_total_size
)
3 changes: 3 additions & 0 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ def all_class_non_max_suppression(
num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)

selected_indices, selected_scores = collect_selected_indices_and_scores(
selected_indices,
selected_scores,
Expand All @@ -860,7 +861,9 @@ def all_class_non_max_suppression(
num_total_detections,
_collect_selected_indices_and_scores_ir,
)

topk_indices = topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")

return post_process_max_detections(
selected_indices,
topk_indices,
Expand Down

0 comments on commit b02faae

Please sign in to comment.