Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ONNX] NMS in ONNX #6839

Merged
Merged 10 commits on Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,12 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAtt

/*! \brief Attributes used in get_valid_counts operator */
struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
double score_threshold;
Optional<FloatImm> score_threshold;
int id_index;
int score_index;

TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
TVM_ATTR_FIELD(score_threshold)
.set_default(0.0)
.describe("Lower limit of score for valid bounding boxes.");
TVM_ATTR_FIELD(score_threshold).describe("Lower limit of score for valid bounding boxes.");
TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id.");
TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes.");
}
Expand All @@ -89,7 +87,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
Optional<Integer> max_output_size;
double iou_threshold;
Optional<FloatImm> iou_threshold;
bool force_suppress;
int top_k;
int coord_start;
Expand All @@ -100,9 +98,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA

TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
TVM_ATTR_FIELD(iou_threshold)
.set_default(0.5)
.describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(iou_threshold).describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(force_suppress)
.set_default(false)
.describe("Suppress all detections regardless of class_id.");
Expand Down
269 changes: 269 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2303,6 +2303,274 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.If(cond, then_expr, else_expr)


class NonMaxSuppression(OnnxOpConverter):
    """Operator converter for NonMaxSuppression."""

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        """
        High level note: ONNX implements what TF calls combined_non_max_suppression
        It passes in scores for each box for every class in the output and expects boxes to be
        analyzed for each class independently

        It also asks for the data to be returned in a particular format.

        To support these, we implement a series of loops:
        The first loop splits over class number, performs NMS, and collects the outputs.
        The second (nested) loop takes the outputs and transforms them into the format ONNX wants
        """
        # Get parameter values. Per the ONNX NonMaxSuppression signature:
        # boxes is (batch, num_boxes, 4), scores is (batch, num_classes, num_boxes),
        # and the remaining three inputs are optional scalar tensors.
        boxes = inputs[0]
        scores = inputs[1]
        max_output_boxes_per_class = inputs[2]
        iou_threshold = inputs[3]
        score_threshold = inputs[4]

        dtype = infer_type(boxes).checked_type.dtype

        # center_point_box == 1 (center/width/height format) is not handled here.
        if "center_point_box" in attr:
            assert (
                attr["center_point_box"] == 0
            ), "Only support center_point_box = 0 in onnx importer right now"

        # The threshold inputs are optional; fall back to 0.0 when omitted.
        if iou_threshold is None:
            iou_threshold = _expr.const(0.0, dtype="float32")
        if score_threshold is None:
            score_threshold = _expr.const(0.0, dtype="float32")

        def conditionally_squeeze_scalar(x):
            # ONNX may pass these values as rank-0 or rank-1 tensors;
            # relay's NMS ops want true scalars, so drop a leading unit dim if present.
            rank = len(infer_shape(x))
            assert rank <= 1, "nms thresholds must be scalars"
            if rank == 1:
                return _op.squeeze(x, [0])
            return x

        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
        score_threshold = conditionally_squeeze_scalar(score_threshold)

        ## prepare utility constants (1-D int64, as required by strided_slice/concatenate below)
        zero = _op.const(np.array([0]), dtype="int64")
        one = _op.const(np.array([1]), dtype="int64")
        two = _op.const(np.array([2]), dtype="int64")
        three = _op.const(np.array([3]), dtype="int64")
        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")

        ## First loop: split by class and perform NMS
        # Create Loop Vars.  B = batch size, C = number of classes,
        # S = number of boxes (all dynamic, each carried as a shape-(1,) tensor).
        i = _expr.var("i", shape=(1,), dtype="int64")
        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
        max_output_boxes_per_class_var = _expr.var(
            "max_output_boxes_per_class_var", shape=(), dtype="int64"
        )
        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
        B = _expr.var("B", shape=(1,), dtype="int64")
        C = _expr.var("C", shape=(1,), dtype="int64")
        S = _expr.var("S", shape=(1,), dtype="int64")
        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
        # and sizes of valid outputs, shape (B, C, 1)
        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")

        def _first_cond(
            i,
            scores,
            boxes,
            B,
            C,
            S,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            onnx_out,
            nms_size_out,
        ):
            # Loop over classes, end when i == C
            return _op.min(_op.less(i, C))

        def _first_body(
            i,
            scores,
            boxes,
            B,
            C,
            S,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            onnx_out,
            nms_size_out,
        ):
            # slice to get current class
            begin = _op.concatenate([zero, i, zero], axis=0)
            end = _op.concatenate([B, i + one, S], axis=0)
            class_scores = _op.strided_slice(scores, begin, end, three_ones)
            class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
            # combine scores and boxes: (B, S, 5) with score at index 0, coords at 1..4
            data = _op.concatenate([class_scores, boxes], axis=-1)

            # get valid counts
            ct, data, indices = _op.vision.get_valid_counts(
                data, score_threshold=score_threshold, id_index=-1, score_index=0
            )
            # reason why using get_valid_counts is for inference performance
            # ONNX NMS doesn't have parameter top_k
            top_k = -1
            # ONNX doesn't have class id for nms input
            score_index = 0
            # perform nms on current class
            nms_ret = _op.vision.non_max_suppression(
                data=data,
                valid_count=ct,
                indices=indices,
                max_output_size=max_output_boxes_per_class,
                iou_threshold=iou_threshold,
                force_suppress=True,
                top_k=top_k,
                coord_start=1,
                score_index=score_index,
                id_index=-1,
                return_indices=True,
                invalid_to_bottom=False,
            )
            # partially prepare ONNX output format by labeling batch_num, class_id:
            # each selected box becomes a [batch_index, class_index, box_index] triple
            nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
            batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
            batch_num = _op.expand_dims(batch_num, -1, 1)
            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
            new_onnx_out = _op.concatenate(
                [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
            )
            new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
            # store valid nms outputs for this class
            nms_size = _op.cast(nms_ret[1], "int64")
            nms_size = _op.expand_dims(nms_size, 1, 1)
            # Append this class's results along the class axis (axis 1); the
            # placeholder slab at class index 0 from the loop's init value stays in front.
            return [
                i + one,
                scores,
                boxes,
                B,
                C,
                S,
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
                _op.concatenate([onnx_out, new_onnx_out], axis=1),
                _op.concatenate([nms_size_out, nms_size], axis=1),
            ]

        # create the first loop
        first_loop = _loops.while_loop(
            _first_cond,
            [
                i,
                scores_var,
                boxes_var,
                B,
                C,
                S,
                max_output_boxes_per_class_var,
                iou_threshold_var,
                score_threshold_var,
                onnx_out,
                nms_size_out,
            ],
            _first_body,
        )

        ## Second loop slices outputs of the first loop for valid boxes and
        ## concats in the order ONNX wants
        # Second inner Loop Vars
        i = _expr.var("i", shape=(1,), dtype="int64")
        j = _expr.var("j", shape=(1,), dtype="int64")
        B = _expr.var("B", shape=(1,), dtype="int64")
        C = _expr.var("C", shape=(1,), dtype="int64")
        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
        # and sizes of valid outputs, shape (B, C, 1)
        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

        def _inner_cond(i, j, C, onnx_out, nms_size, out):
            # inner loop over number of classes
            return _op.min(_op.less(j, C))

        def _inner_body(i, j, C, onnx_out, nms_size, out):
            # slice to get current batch and class for valid box indicator
            # (class index is j + one to skip the placeholder slab at class 0)
            start = _op.concatenate([i, j + one, zero], axis=0)
            end = _op.concatenate([i + one, j + two, one], axis=0)
            num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
            # slice to get current batch, class, and valid outputs
            start = _op.concatenate([i, j + one, zero, zero], axis=0)
            end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
            new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
            return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)

        inner_loop = _loops.while_loop(
            _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
        )

        # Second Outer Loop Vars
        i = _expr.var("i", shape=(1,), dtype="int64")
        j = _expr.var("j", shape=(1,), dtype="int64")
        B = _expr.var("B", shape=(1,), dtype="int64")
        C = _expr.var("C", shape=(1,), dtype="int64")
        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
        # and sizes of valid outputs, shape (B, C, 1)
        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

        def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
            # Outer loop is over batch size
            return _op.min(_op.less(i, B))

        def _outer_body(i, B, C, onnx_out, nms_size_out, out):
            # Outer loop just calls inner loop
            init_count = _op.const(np.array([0]), dtype="int64")
            inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
            return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)

        # Create the second loop
        outer_loop = _loops.while_loop(
            _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
        )

        # Call the first loop, perform NMS
        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
        init_count = _op.const(np.array([0]), dtype="int64")
        # Placeholder slabs (all ones) at class index 0; the second loop skips them
        # by always slicing from class j + one.
        init_onnx_out = _op.const([1], dtype="int64")
        init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
        init_nms_size_out = _op.const([1], dtype="int64")
        init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
        loop_vals = first_loop(
            init_count,
            scores,
            boxes,
            B,
            C,
            S,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            init_onnx_out,
            init_nms_size_out,
        )
        # Positions 9 and 10 in the loop-var tuple are onnx_out and nms_size_out.
        onnx_output = _expr.TupleGetItem(loop_vals, 9)
        nms_size_output = _expr.TupleGetItem(loop_vals, 10)

        # Call the second loop, rework outputs into correct form
        init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
        init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64")
        loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)

        # Position 5 is the accumulated (num_selected, 3) output tensor.
        return _expr.TupleGetItem(loop_vals, 5)


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -2415,6 +2683,7 @@ def _get_convert_map(opset):
# defs/vision
"MaxRoiPool": MaxRoiPool.get_converter(opset),
"RoiAlign": RoiAlign.get_converter(opset),
"NonMaxSuppression": NonMaxSuppression.get_converter(opset),
# defs/reduction
"ReduceMax": ReduceMax.get_converter(opset),
"ReduceMin": ReduceMin.get_converter(opset),
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,9 +885,11 @@ def wrap_compute_get_valid_counts(topi_compute):
"""wrap get_valid_counts topi compute"""

def _compute_get_valid_counts(attrs, inputs, out_type):
score_threshold = get_const_float(attrs.score_threshold)
score_threshold = inputs[1]
id_index = get_const_int(attrs.id_index)
score_index = get_const_int(attrs.score_index)
if attrs.score_threshold is not None:
score_threshold = get_const_float(attrs.score_threshold)
return topi_compute(inputs[0], score_threshold, id_index, score_index)

return _compute_get_valid_counts
Expand All @@ -911,10 +913,12 @@ def wrap_compute_nms(topi_compute):

def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
iou_threshold = inputs[4]
if attrs.max_output_size is not None:
max_output_size = attrs.max_output_size
if attrs.iou_threshold is not None:
iou_threshold = get_const_float(attrs.iou_threshold)
return_indices = bool(get_const_int(attrs.return_indices))
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
coord_start = get_const_int(attrs.coord_start)
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
out_indices: relay.Expr
Indices in input data
"""
if not isinstance(score_threshold, expr.Expr):
score_threshold = expr.const(score_threshold, "float32")
return expr.TupleWrapper(
_make.get_valid_counts(data, score_threshold, id_index, score_index), 3
)
Expand Down Expand Up @@ -94,7 +96,7 @@ def non_max_suppression(
Max number of output valid boxes for each instance.
Return all valid boxes if the value of max_output_size is less than 0.

iou_threshold : float, optional
iou_threshold : float or relay.Expr, optional
Non-maximum suppression threshold.

force_suppress : bool, optional
Expand Down Expand Up @@ -126,8 +128,10 @@ def non_max_suppression(
If return_indices is True, return relay.Tuple of two 2-D tensors, with
shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
"""
if isinstance(max_output_size, int):
if not isinstance(max_output_size, expr.Expr):
max_output_size = expr.const(max_output_size, "int32")
if not isinstance(iou_threshold, expr.Expr):
iou_threshold = expr.const(iou_threshold, "float32")
out = _make.non_max_suppression(
data,
valid_count,
Expand Down
Loading