Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, TF] Support converting TF combined_nms using Relay all_class_nms #8174

Merged
merged 17 commits into from
Jun 4, 2021
12 changes: 10 additions & 2 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,19 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
}
};

/*! \brief Attributes used in non_maximum_suppression operator */
/*! \brief Attributes used in all_class_non_maximum_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
std::string output_format;

TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(output_format)
.set_default("onnx")
.describe(
"Output format, onnx or tensorflow. Returns outputs in a way that can be easily "
"consumed by each frontend.");
}
};

/*! \brief Attributes used in roi_align operators */
Expand Down
106 changes: 103 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,88 @@ def _impl(inputs, attr, params, mod):
return _impl


def convert_combined_nms_with_all_class_nms(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I want to use this function to rewrite ONNX models, it is defined at the top level and visible from outside.

batch_size,
max_output_boxes_per_batch,
num_class,
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
clip_boxes,
):
"""Converts TF combined_nms using Relay all_class_max_suppression op"""
(selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression(
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
output_format="tensorflow",
)
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
assert isinstance(batch_size, int), "dynamic batch size not supported yet."
tile_batch_reps = _op.const([batch_size, 1])
box_range_2d = _op.tile(box_range, tile_batch_reps)
valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)

def select_topk(do_zero_pad):
def true_branch():
arange = _op.arange(
_op.const(0, dtype="int64"),
_op.const(max_output_boxes_per_batch, dtype="int64"),
dtype="int64",
)
pad = _op.full(
_op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,)
)
topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps)
nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
nmsed_scores = nmsed_scores * valid_mask
return nmsed_scores, topk_indices

def false_branch():
if isinstance(max_output_boxes_per_class, int):
# Do topk on smaller input if possible
slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64")
selected_scores_slice = _op.strided_slice(
selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1]
)
else:
selected_scores_slice = selected_scores
return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both")

# TODO(masahi): support dynamic num_boxes
# return _expr.If(do_zero_pad, true_branch(), false_branch())
return true_branch() if do_zero_pad else false_branch()

assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet."
nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size)

indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1)
nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.take(indices, _op.const(0), axis=2)
nmsed_classes = _op.cast(nmsed_classes, "float32")
nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1)
num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))

if clip_boxes:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4
)


def _combined_nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
Expand Down Expand Up @@ -821,9 +903,27 @@ def _impl(inputs, attr, params, mod):
q = boxes_shape[2]
num_classes = scores_shape[2]

if q != num_classes:
# When q is 1, it means same box coords are used for all classes.
boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4))
assert isinstance(batch_size, int) and isinstance(
num_anchors, int
), "Dynamic inputs not supported yet"

if q == 1:
boxes = _op.squeeze(boxes, axis=[2])
scores_trans = _op.transpose(scores, [0, 2, 1])
max_output_boxes_per_batch = num_anchors * num_classes
return convert_combined_nms_with_all_class_nms(
batch_size,
max_output_boxes_per_batch,
num_classes,
boxes,
scores_trans,
max_output_size,
iou_threshold,
score_threshold,
max_total_size.data.numpy().item(),
attr["clip_boxes"],
)

boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4])
scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1])

Expand Down
10 changes: 9 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,15 @@ def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)
output_format = attrs.output_format
return topi_compute(
inputs[0],
inputs[1],
max_output_size,
iou_threshold,
score_threshold,
output_format,
)

return _compute_nms

Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def nms_shape_func(attrs, inputs, _):


@script
def _all_class_nms_shape_func(boxes_shape, scores_shape):
def _all_class_nms_shape_func_onnx(boxes_shape, scores_shape):
out_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

Expand All @@ -99,9 +99,27 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape):
return out_shape, count_shape


@script
def _all_class_nms_shape_func_tf(boxes_shape, scores_shape):
out_indices_shape = output_tensor((3,), "int64")
out_scores_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

out_indices_shape[0] = boxes_shape[0]
out_indices_shape[1] = scores_shape[1] * boxes_shape[1]
out_indices_shape[2] = int64(2)
out_scores_shape[0] = boxes_shape[0]
out_scores_shape[1] = scores_shape[1] * boxes_shape[1]
count_shape[0] = boxes_shape[0]

return out_indices_shape, out_scores_shape, count_shape


@reg.register_shape_func("vision.all_class_non_max_suppression", False)
def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func(inputs[0], inputs[1])
if attrs.output_format == "onnx":
return _all_class_nms_shape_func_onnx(inputs[0], inputs[1])
return _all_class_nms_shape_func_tf(inputs[0], inputs[1])


@script
Expand Down
41 changes: 35 additions & 6 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def non_max_suppression(


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
boxes,
scores,
max_output_boxes_per_class=-1,
iou_threshold=-1.0,
score_threshold=-1.0,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -175,16 +180,31 @@ def all_class_non_max_suppression(
score_threshold : float or relay.Expr, optional
Score threshold to filter out low score boxes early

output_format : string, optional
"onnx" or "tensorflow". Specify by which frontends the outputs are
intented to be consumed.

Returns
-------
out : relay.Tuple
The output is a relay.Tuple of two tensors, the first is `indices` of size
`(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor
`num_total_detection` of shape `(1,)` representing the total number of selected boxes.
If `output_format` is "onnx", the output is a relay.Tuple of two tensors, the first is
`indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar
tensor `num_total_detection` of shape `(1,)` representing the total number of selected
boxes. The three values in `indices` encode batch, class, and box indices.
Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first,
in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
`batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
rows are valid.

If `output_format` is "tensorflow", the output is a relay.Tuple of three tensors, the first
is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of
size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
`(batch_size,)` representing the total number of selected boxes per batch. The two values
in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at
batch b, only the first `num_total_detection[b]` entries are valid. The second axis of
`indices` and `scores` are sorted within each class by box scores, but not across classes.
So the box indices and scores for the class 0 come first in a sorted order, followed by
the class 1 etc.
"""
if not isinstance(max_output_boxes_per_class, expr.Expr):
max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
Expand All @@ -194,6 +214,15 @@ def all_class_non_max_suppression(
score_threshold = expr.const(score_threshold, "float32")

out = _make.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
output_format,
)
return expr.TupleWrapper(out, 2)

if output_format == "onnx":
return expr.TupleWrapper(out, 2)

return expr.TupleWrapper(out, 3)
Loading