Skip to content

Commit

Permalink
fix nnvm compatibility issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
yongwww committed Nov 18, 2019
1 parent 4d0e056 commit 59d4fb7
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 21 deletions.
6 changes: 5 additions & 1 deletion nnvm/python/nnvm/to_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,11 @@ def _strided_slice(children, attrs, odtype='float32'):
begin = attrs.get_int_list('begin')
end = attrs.get_int_list('end')
strides = attrs.get_int_list('stride', None)
return op.strided_slice(children[0], begin, end, strides=strides)
strides = [1] * len(begin) if strides is None else strides
return op.strided_slice(children[0],
expr.const(list(begin), "int32"),
expr.const(list(end), "int32"),
strides=expr.const(list(strides), "int32"))


def _split(children, attrs, odtype='float32'):
Expand Down
4 changes: 3 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,9 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold,
max_output_size=out_size, iou_threshold=iou_threshold,
score_threshold=score_threshold, name="nms")
compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
'nms/NonMaxSuppressionV3:0', no_gpu=True, mode='vm')
'nms/NonMaxSuppressionV3:0', mode='vm')
compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
'nms/NonMaxSuppressionV3:0', mode='debug')

def test_forward_nms_v3():
""" NonMaxSuppressionV3 """
Expand Down
5 changes: 4 additions & 1 deletion topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):


@non_max_suppression.register(["cuda", "gpu"])
def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
def non_max_suppression_gpu(data, valid_count, indices, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
coord_start=2, score_index=1, id_index=0,
return_indices=True, invalid_to_bottom=False):
Expand All @@ -691,6 +691,9 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
indices : tvm.Tensor
2-D tensor with shape [batch_size, num_anchors].
max_output_size : optional, int
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
Expand Down
4 changes: 3 additions & 1 deletion topi/python/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ def _conv2d_legalize(attrs, inputs, arg_types):
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
out = relay.strided_slice(out,
begin=relay.const([0, 0, 0, 0], "int32"),
end=relay.const(original_out_shape, "int32"))
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)

Expand Down
45 changes: 28 additions & 17 deletions topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,17 @@ def check_device(device):
tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx)
f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
if device == "llvm":
f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
else:
f = tvm.build(s, [data, outs[0], outs[1]], device)
f(tvm_input_data, tvm_out1, tvm_out2)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

for device in ['llvm', 'cuda', 'opencl']:
# Disable gpu test for now
Expand Down Expand Up @@ -108,14 +114,16 @@ def check_device(device):
return_indices=False)
indices_out = non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k,
coord_start=coord_start, score_index=score_index, id_index=id_index)
s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out[0])
else:
out = topi.cuda.non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k,
coord_start=coord_start, score_index=score_index, id_index=id_index,
return_indices=False)
indices_out = topi.cuda.non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k,
coord_start=coord_start, score_index=score_index, id_index=id_index)
s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out[0])
s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out)

tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
Expand All @@ -127,10 +135,13 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
# TODO (yongwww): add dynamic nms for gpu
# tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
if device == 'llvm':
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
else:
f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
Expand All @@ -141,24 +152,24 @@ def test_non_max_suppression():
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[4, 0, -1, -1, -1]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])

verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.6, True, 2, 2, 1, 0)
verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0)

np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80],
[0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79],
[0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]]])
np_indices_result = np.array([[4, 0, -1, -1, -1]])
verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.6, False, 2, 1, 0, -1)
np_indices_result = np.array([[3, 0, -1, -1, -1]])
verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1)


def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
Expand Down

0 comments on commit 59d4fb7

Please sign in to comment.