From 59d4fb7095f6d56d12940c6323bd253df64e98db Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sun, 17 Nov 2019 18:39:48 +0000 Subject: [PATCH] fix nnvm compatibility issues --- nnvm/python/nnvm/to_relay.py | 6 ++- .../frontend/tensorflow/test_forward.py | 4 +- topi/python/topi/cuda/nms.py | 5 ++- topi/python/topi/x86/conv2d_alter_op.py | 4 +- topi/tests/python/test_topi_vision.py | 45 ++++++++++++------- 5 files changed, 43 insertions(+), 21 deletions(-) diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py index 94a736dabe70..3e7cb2c35ec0 100644 --- a/nnvm/python/nnvm/to_relay.py +++ b/nnvm/python/nnvm/to_relay.py @@ -244,7 +244,11 @@ def _strided_slice(children, attrs, odtype='float32'): begin = attrs.get_int_list('begin') end = attrs.get_int_list('end') strides = attrs.get_int_list('stride', None) - return op.strided_slice(children[0], begin, end, strides=strides) + strides = [1] * len(begin) if strides is None else strides + return op.strided_slice(children[0], + expr.const(list(begin), "int32"), + expr.const(list(end), "int32"), + strides=expr.const(list(strides), "int32")) def _split(children, attrs, odtype='float32'): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 61f1f440dc24..d5b598ecb675 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1605,7 +1605,9 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, max_output_size=out_size, iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms") compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], - 'nms/NonMaxSuppressionV3:0', no_gpu=True, mode='vm') + 'nms/NonMaxSuppressionV3:0', mode='vm') + compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + 'nms/NonMaxSuppressionV3:0', mode='debug') def test_forward_nms_v3(): """ NonMaxSuppressionV3 """ diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index d032527ec273..62b504c18263 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -675,7 +675,7 @@ def invalid_to_bottom_ir(data, flag, idx, out): @non_max_suppression.register(["cuda", "gpu"]) -def non_max_suppression_gpu(data, valid_count, max_output_size=-1, +def non_max_suppression_gpu(data, valid_count, indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, return_indices=True, invalid_to_bottom=False): @@ -691,6 +691,9 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, valid_count : tvm.Tensor 1-D tensor for valid number of boxes. + indices : tvm.Tensor + 2-D tensor with shape [batch_size, num_anchors]. + max_output_size : optional, int Max number of output valid boxes for each instance. By default all valid boxes are returned. diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index f596bc0eb503..b6408e22f67c 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -302,7 +302,9 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + out = relay.strided_slice(out, + begin=relay.const([0, 0, 0, 0], "int32"), + end=relay.const(original_out_shape, "int32")) else: out = relay.nn.conv2d(data, kernel, **new_attrs) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index fd686216b45a..0bae0edc73ce 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -65,11 +65,17 @@ def check_device(device): tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx) tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx) - f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device) - f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3) - tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) - tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) + if device == "llvm": + f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device) + f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) + else: + f = tvm.build(s, [data, outs[0], outs[1]], device) + f(tvm_input_data, tvm_out1, tvm_out2) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) for device in ['llvm', 'cuda', 'opencl']: # Disable gpu test for now @@ -108,14 +114,16 @@ def check_device(device): return_indices=False) indices_out = non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, coord_start=coord_start, score_index=score_index, id_index=id_index) + s = topi.generic.schedule_nms(out) + indices_s = topi.generic.schedule_nms(indices_out[0]) else: out = topi.cuda.non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, coord_start=coord_start, score_index=score_index, id_index=id_index, return_indices=False) indices_out = topi.cuda.non_max_suppression(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, coord_start=coord_start, score_index=score_index, id_index=id_index) - s = topi.generic.schedule_nms(out) - indices_s = topi.generic.schedule_nms(indices_out[0]) + s = topi.generic.schedule_nms(out) + indices_s = topi.generic.schedule_nms(indices_out) tvm_data = tvm.nd.array(np_data, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx) @@ -127,10 +135,13 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx) - f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device) - f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) - # TODO (yongwww): add dynamic nms for gpu - # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) + if device == 'llvm': + f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) + else: + f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) + tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) for device in ['llvm', 'cuda', 'opencl']: check_device(device) @@ -141,24 +152,24 @@ def test_non_max_suppression(): [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") - np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[4, 0, -1, -1, -1]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.6, True, 2, 2, 1, 0) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], [0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") - np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[4, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.6, False, 2, 1, 0, -1) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):