Skip to content

Commit

Permalink
clean up
Browse files: browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
1 parent 0c659bf commit a1fe7c4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 39 deletions.
31 changes: 11 additions & 20 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,25 +1100,6 @@ def all_class_non_max_suppression(
sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)

- if output_format == "onnx":
- selected_indices, num_detections = run_all_class_nms(
- boxes,
- sorted_scores,
- sorted_indices,
- valid_count,
- max_output_boxes_per_class,
- iou_threshold,
- _nms_loop,
- )
-
- row_offsets, num_total_detections = exclusive_scan(
- num_detections, return_reduction=True, output_dtype="int64"
- )
- selected_indices = collect_selected_indices(
- num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
- )
- return [selected_indices, num_total_detections]

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
Expand All @@ -1127,8 +1108,18 @@ def all_class_non_max_suppression(
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
- return_scores=True,
+ return_scores=(output_format == "tensorflow"),
)

+ if output_format == "onnx":
+ row_offsets, num_total_detections = exclusive_scan(
+ num_detections, return_reduction=True, output_dtype="int64"
+ )
+ selected_indices = collect_selected_indices(
+ num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
+ )
+ return [selected_indices, num_total_detections]

num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1
Expand Down
29 changes: 11 additions & 18 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,23 +829,6 @@ def all_class_non_max_suppression(

valid_count = _get_valid_box_count(sorted_scores, score_threshold)

- if output_format == "onnx":
- selected_indices, num_detections = run_all_class_nms(
- boxes,
- sorted_scores,
- sorted_indices,
- valid_count,
- max_output_boxes_per_class,
- iou_threshold,
- _nms_loop,
- )
- row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
- num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
- selected_indices = collect_selected_indices(
- num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
- )
- return [selected_indices, num_total_detections]

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
Expand All @@ -854,8 +837,18 @@ def all_class_non_max_suppression(
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
- return_scores=True,
+ return_scores=(output_format == "tensorflow"),
)

+ if output_format == "onnx":
+ row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
+ num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)
+
+ selected_indices = collect_selected_indices(
+ num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
+ )
+ return [selected_indices, num_total_detections]

num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def run_all_class_nms(
)

if return_scores is False:
- return te.extern(
+ selected_indices, num_detections = te.extern(
[(batch_class, num_boxes), (1, batch_class)],
[boxes, sorted_scores, sorted_indices, valid_count],
lambda ins, outs: _all_class_nms_ir(
Expand Down Expand Up @@ -334,6 +334,7 @@ def run_all_class_nms(
name="all_class_nms",
tag="all_class_nms",
)
+ return selected_indices, None, num_detections

return te.extern(
[(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)],
Expand Down

0 comments on commit a1fe7c4

Please sign in to comment.