Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris authored and masahi committed May 29, 2021
1 parent 168a617 commit 0044365
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 11 deletions.
3 changes: 3 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
/*! \brief Attributes used in non_maximum_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
std::string output_format;
TVM_ATTR_FIELD(output_format).set_default("onnx").describe(
"Output format. onnx or tensorflow");
TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
};
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def _impl(inputs, attr, params, mod):
# Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes)
scores = _op.transpose(scores, [0, 2, 1])
indices, count = _op.vision.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="tensorflow"
)
# Slice indices to count
three = _op.const(np.array([3]), dtype="int64")
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def non_max_suppression(


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0, output_format="onnx"
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand Down Expand Up @@ -185,6 +185,7 @@ def all_class_non_max_suppression(
in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
`batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
rows are valid.
TODO(trvmorr): explain tf mode
"""
if not isinstance(max_output_boxes_per_class, expr.Expr):
max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
Expand All @@ -194,6 +195,6 @@ def all_class_non_max_suppression(
score_threshold = expr.const(score_threshold, "float32")

out = _make.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
)
return expr.TupleWrapper(out, 2)
13 changes: 9 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -1011,6 +1011,8 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
output_format : str
Returns
-------
Expand Down Expand Up @@ -1043,8 +1045,11 @@ def all_class_non_max_suppression(
num_detections, return_reduction=True, output_dtype="int64"
)

selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)
if output_format == "onnx":
selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)
else:
selected_indices = reshape(selected_indices, (batch, num_class, num_boxes))

return [selected_indices, num_total_detections]
11 changes: 7 additions & 4 deletions python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -750,6 +750,8 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
output_format : TODO
Returns
-------
Expand Down Expand Up @@ -783,8 +785,9 @@ def all_class_non_max_suppression(

num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1)

selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)
if output_format == "onnx":
selected_indices = collect_selected_indices(
num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir
)

return [selected_indices, num_total_detections]
53 changes: 53 additions & 0 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,59 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of
)


def collect_selected_indices_tf(num_class, selected_indices, num_detections, row_offsets, ir):
"""Collect selected indices from the core NMS loop into one linear output
Parameters
----------
num_class : int
selected_indices: tvm.te.Tensor
2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
of selected boxes by the core NMS loop.
num_detections tvm.te.Tensor
1-D tensor with shape (batch_size * num_classes,), representing
the number of boxes selected by the core NMS loop, per batch and class
row_offsets tvm.te.Tensor
1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan
of num_detections
ir : function
A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py
Returns
-------
out : tvm.te.Tensor
The output is indices of size (batch_size * num_class* num_boxes , 3).
Rows of indices are ordered such that selected boxes from batch 0, class 0 come
first, in descending of scores, followed by boxes from batch 0, class 1 etc.
"""
batch_class, num_boxes = selected_indices.shape

selected_indices_buf = tvm.tir.decl_buffer(
selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8
)
num_detections_buf = tvm.tir.decl_buffer(
num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8
)
row_offsets_buf = tvm.tir.decl_buffer(
row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8
)

return te.extern(
[(batch_size, num_class * num_boxes, 2)],
[selected_indices, num_detections, row_offsets],
lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]),
dtype=["int64"],
in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf],
name="collect_indices",
tag="collect_indices",
)



def _all_class_nms_ir(
boxes,
sorted_scores,
Expand Down

0 comments on commit 0044365

Please sign in to comment.