Skip to content

Commit

Permalink
[Relay] symbolic max_output_size (apache#5844)
Browse files Browse the repository at this point in the history
* symbolic max_output_size

* pylint

* fix ci
  • Loading branch information
yongwww authored and zhiics committed Jul 2, 2020
1 parent 4eb0783 commit 8aeb02d
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 86 deletions.
8 changes: 2 additions & 6 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {

/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
int max_output_size;
Optional<Integer> max_output_size;
double iou_threshold;
bool force_suppress;
int top_k;
Expand All @@ -99,11 +99,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
bool invalid_to_bottom;

TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size)
.set_default(-1)
.describe(
"Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
TVM_ATTR_FIELD(iou_threshold)
.set_default(0.5)
.describe("Non-maximum suppression iou threshold.");
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,6 @@ def _impl(inputs, attr, params, mod):
def _nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
# TODO(yongwww) change nms in relay to support symbolic max_output_size
try:
max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy()
.astype("int64"))[0])
Expand All @@ -634,7 +633,7 @@ def _impl(inputs, attr, params, mod):
max_output_size = _infer_value(inputs[2], params,
mod).asnumpy().astype("int64").tolist()[0]
except Exception:
max_output_size = -1
max_output_size = inputs[2]
iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
# score_threshold was introduced from V3
score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,10 @@ def get_valid_counts_strategy(attrs, inputs, out_type, target):
def wrap_compute_nms(topi_compute):
"""wrap nms topi compute"""
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
if attrs.max_output_size is not None:
max_output_size = attrs.max_output_size
return_indices = bool(get_const_int(attrs.return_indices))
max_output_size = get_const_int(attrs.max_output_size)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def non_max_suppression(data,
second dimension are like the output of arange(num_anchors)
if get_valid_counts is not used before non_max_suppression.
max_output_size : int, optional
max_output_size : int or relay.Expr, optional
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
Return all valid boxes if the value of max_output_size is less than 0.
iou_threshold : float, optional
Non-maximum suppression threshold.
Expand Down Expand Up @@ -124,9 +124,11 @@ def non_max_suppression(data,
out : relay.Expr or relay.Tuple
return relay.Expr if return_indices is disabled, a 3-D tensor
with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5].
if return_indices is True, return relay.Tuple of two 2-D tensors, with
If return_indices is True, return relay.Tuple of two 2-D tensors, with
shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
"""
if isinstance(max_output_size, int):
max_output_size = expr.const(max_output_size, "int32")
out = _make.non_max_suppression(data,
valid_count,
indices,
Expand Down
14 changes: 7 additions & 7 deletions src/relay/op/vision/nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);

bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
CHECK_EQ(types.size(), 5);
const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>();
const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>();
Expand All @@ -90,18 +90,17 @@ bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
fields.push_back(TensorType(oshape, DataType::Int(32)));
std::vector<IndexExpr> countshape({dshape[0], 1});
fields.push_back(TensorType(countshape, DataType::Int(32)));
reporter->Assign(types[3], TupleType(Array<Type>(fields)));
reporter->Assign(types[4], TupleType(Array<Type>(fields)));
} else {
reporter->Assign(types[3], TensorType(dshape, data->dtype));
reporter->Assign(types[4], TensorType(dshape, data->dtype));
}
return true;
}

Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold,
Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold,
bool force_suppress, int top_k, int coord_start, int score_index, int id_index,
bool return_indices, bool invalid_to_bottom) {
auto attrs = make_object<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress;
attrs->top_k = top_k;
Expand All @@ -111,7 +110,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, dou
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
static const Op& op = Op::Get("vision.non_max_suppression");
return Call(op, {data, valid_count, indices}, Attrs(attrs), {});
return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS);
Expand All @@ -122,10 +121,11 @@ be in the format of [class_id, score, left, top, right, bottom]
or [score, left, top, right, bottom]. Set id_index to be -1 to
ignore class_id axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(3)
.set_num_inputs(4)
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.add_argument("indices", "Tensor", "Corresponding indices in original input tensor.")
.add_argument("max_output_size", "Tensor", "Max number of output valid boxes.")
.set_support_level(5)
.add_type_rel("NMS", NMSRel);

Expand Down
12 changes: 7 additions & 5 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,22 +2035,24 @@ def test_forward_crop_and_resize():
def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
scores = np.random.uniform(size=score_shape).astype(dtype)
max_output_size = np.int32(out_size)
tf.reset_default_graph()
in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2,
max_output_size=out_size, iou_threshold=iou_threshold,
score_threshold=score_threshold, name="nms")
compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms")
compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
'nms/NonMaxSuppressionV3:0', mode='vm')
compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
'nms/NonMaxSuppressionV3:0', mode='debug')

def test_forward_nms_v3():
""" NonMaxSuppressionV3 """
_test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5)
_test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10)
_test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000)
_test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7)


#######################################################################
Expand Down
40 changes: 21 additions & 19 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,17 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):


def test_non_max_suppression():
def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res,
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False):
def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res,
ref_indices_res, iou_threshold=0.5, force_suppress=False,
top_k=-1, check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32"))
z = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \
x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
z = relay.vision.non_max_suppression(x0, x1, x2, x3, \
iou_threshold=iou_threshold, force_suppress=force_suppress, \
top_k=top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \
z_indices = relay.vision.non_max_suppression(x0, x1, x2, x3, \
iou_threshold=iou_threshold, force_suppress=force_suppress, \
top_k=top_k, return_indices=True)
if isinstance(z_indices, relay.expr.TupleWrapper):
Expand All @@ -309,30 +310,30 @@ def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res,
if check_type_only:
return

func = relay.Function([x0, x1, x2], z)
func = relay.Function([x0, x1, x2, x3], z)
func = run_infer_type(func)
func_indices = relay.Function([x0, x1, x2], z_indices)
func_indices = relay.Function([x0, x1, x2, x3], z_indices)
func_indices = run_infer_type(func_indices)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
if target == 'cuda':
return
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5)

np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")

np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
np_max_output_size = -1

np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
Expand All @@ -341,22 +342,23 @@ def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res,
num_anchors = 5

dshape = (te.size_var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result,
verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result,
verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=False)

np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[4, 0, 1, -1, -1]])
np_indices_result = np.array([[4, 0, -1, -1, -1]])
np_max_output_size = 2
dshape = (te.size_var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, np_indices, dshape, np_result,
verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, np_indices, dshape, np_result,
np_indices_result, top_k=3)
verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
np_indices_result, top_k=2)


def test_multibox_transform_loc():
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
name="nms",
tag="nms")

# TODO(yongwww): Update cuda nms to be consistent with cpu version
if return_indices:
return box_indices

Expand Down
Loading

0 comments on commit 8aeb02d

Please sign in to comment.