From 778c8e8226ca609a3e0b4b0081b80d69c1fe6213 Mon Sep 17 00:00:00 2001 From: bindog Date: Sun, 29 Sep 2019 01:22:01 +0800 Subject: [PATCH] [Fix] Add more pad_mode support for onnx converter (#4029) * [Fix] Add more pad_mode support for onnx converter * robustness fix --- python/tvm/relay/frontend/onnx.py | 22 ++++++++---- tests/python/frontend/onnx/test_forward.py | 39 ++++++++++++++-------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e5309367c08be..387e9cfe3ce4f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -326,15 +326,20 @@ def _impl_v1(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width + pad_mode = attr.get('mode', 'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( _op.nn.pad, transforms={ 'value': 'pad_value', }, - ignores=['mode'], - custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', - 'split mode != constant'))(inputs, attr, params) + )(inputs, attr, params) @classmethod def _impl_v2(cls, inputs, attr, params): @@ -344,15 +349,20 @@ def _impl_v2(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width + pad_mode = attr.get('mode', 'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( 'pad', transforms={ 'value': 'pad_value', }, - ignores=['mode'], - custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', - 'split mode != constant'))(inputs, attr, params) + )(inputs, attr, params) class ParametricSoftPlus(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 275ead5c80417..dc9493f322375 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -781,21 +781,31 @@ def test_constantfill(): verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6)) -def verify_pad(indata, pads, value=0.0): +def verify_pad(indata, pads, mode='constant', value=0.0): indata = np.array(indata).astype(np.float32) # numpy expect result len_dim = len(pads) // 2 np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] - outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) # onnx graph - node = helper.make_node( - 'Pad', - inputs=['input'], - outputs=['output'], - mode='constant', - pads=pads, - value=value - ) + if mode in ['edge', 'reflect']: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + 'Pad', + inputs=['input'], + outputs=['output'], + mode=mode, + pads=pads, + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) + node = helper.make_node( + 'Pad', + inputs=['input'], + outputs=['output'], + mode='constant', + pads=pads, + value=value + ) graph = helper.make_graph([node], 'pad_test', inputs = [helper.make_tensor_value_info("input", @@ -809,9 +819,11 @@ def verify_pad(indata, pads, value=0.0): tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) def test_pad(): - verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0) - verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0) - verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0) + verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0) + verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0) + verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0) + verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') + verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') def verify_reduce_x(name, indata, axis, keepdims): indata = np.array(indata).astype(np.float32) @@ -1266,7 +1278,6 @@ def test_erf(): test_forward_arg_min_max() test_softmax() test_constantfill() - test_pad() test_reduce_max() test_reduce_min() test_reduce_sum()