Skip to content

Commit

Permalink
[Relay][OP] Support NMSv4 ingestion from TF. (apache#6085)
Browse files Browse the repository at this point in the history
  • Loading branch information
csullivan authored and Trevor Morris committed Aug 26, 2020
1 parent c342b93 commit 8df1ec8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,11 @@ def _impl(inputs, attr, params, mod):
iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
# score_threshold was introduced from V3
score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
pad_output = 'pad_to_max_output_size'

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims",
ignores=['T_threshold'],
ignores=['T_threshold', pad_output],
extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
data = get_relay_op('concatenate')([scores, inputs[0]], -1)
data = get_relay_op('expand_dims')(data, 0, 1)
Expand All @@ -667,6 +668,8 @@ def _impl(inputs, attr, params, mod):
return_indices=True,
invalid_to_bottom=False)

if pad_output in attr and attr[pad_output]:
return nms_ret
# squeeze it, TF NMS is not batched
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
Expand Down Expand Up @@ -2152,6 +2155,7 @@ def _impl(inputs, attr, params, mod):
'Neg' : AttrCvt('negative'),
'NonMaxSuppressionV2' : _nms(),
'NonMaxSuppressionV3' : _nms(),
'NonMaxSuppressionV4' : _nms(),
'NoOp' : _no_op(),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
Expand Down
33 changes: 26 additions & 7 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,12 +2046,31 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold,
compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
'nms/NonMaxSuppressionV3:0', mode='debug')

def test_forward_nms_v3():
""" NonMaxSuppressionV3 """
_test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5)
_test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10)
_test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000)
_test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7)
def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
scores = np.random.uniform(size=score_shape).astype(dtype)
max_output_size = np.int32(out_size)
tf.reset_default_graph()
in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
indices_padded, num_valid = tf.image.non_max_suppression_padded(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms", pad_to_max_output_size=True)
num_valid = tf.reshape(num_valid,shape=(-1,))
indices_padded = tf.reshape(indices_padded, shape=(-1,))
tf.slice(indices_padded, tf.constant([0]), num_valid, name="SlicedIndices")
compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='vm')
compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='debug')

def test_forward_nms():
""" NonMaxSuppressionV3,4 """
for _test_forward_nms in [_test_forward_nms_v3, _test_forward_nms_v4]:
_test_forward_nms((5, 4), (5,), 0.7, 0.5, 5)
_test_forward_nms((20, 4), (20,), 0.5, 0.6, 10)
_test_forward_nms((1000, 4), (1000,), 0.3, 0.7, 1000)
_test_forward_nms((2000, 4), (2000,), 0.4, 0.6, 7)


#######################################################################
Expand Down Expand Up @@ -3883,7 +3902,7 @@ def lstm_cell():
test_forward_truncatemod()
test_forward_one_hot()
test_forward_atan2()
test_forward_nms_v3()
test_forward_nms()

# Activations
test_forward_sigmoid()
Expand Down

0 comments on commit 8df1ec8

Please sign in to comment.