Skip to content

Commit

Permalink
[Fix] Add more pad_mode support for onnx converter (apache#4029)
Browse files Browse the repository at this point in the history
* [Fix] Add more pad_mode support for onnx converter

* robustness fix
  • Loading branch information
bindog authored and wweic committed Sep 30, 2019
1 parent b866c7d commit 778c8e8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
22 changes: 16 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,20 @@ def _impl_v1(cls, inputs, attr, params):
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
else:
raise tvm.error.OpAttributeInvalid(
'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')

return AttrCvt(
_op.nn.pad,
transforms={
'value': 'pad_value',
},
ignores=['mode'],
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
)(inputs, attr, params)

@classmethod
def _impl_v2(cls, inputs, attr, params):
Expand All @@ -344,15 +349,20 @@ def _impl_v2(cls, inputs, attr, params):
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
else:
raise tvm.error.OpAttributeInvalid(
'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')

return AttrCvt(
'pad',
transforms={
'value': 'pad_value',
},
ignores=['mode'],
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
)(inputs, attr, params)


class ParametricSoftPlus(OnnxOpConverter):
Expand Down
39 changes: 25 additions & 14 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,21 +781,31 @@ def test_constantfill():
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))


def verify_pad(indata, pads, value=0.0):
def verify_pad(indata, pads, mode='constant', value=0.0):
indata = np.array(indata).astype(np.float32)
# numpy expect result
len_dim = len(pads) // 2
np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
# onnx graph
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode='constant',
pads=pads,
value=value
)
if mode in ['edge', 'reflect']:
outdata = np.pad(indata, pad_width=np_pads, mode=mode)
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode=mode,
pads=pads,
)
else:
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode='constant',
pads=pads,
value=value
)
graph = helper.make_graph([node],
'pad_test',
inputs = [helper.make_tensor_value_info("input",
Expand All @@ -809,9 +819,11 @@ def verify_pad(indata, pads, value=0.0):
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)

def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0)
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0)
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')

def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32)
Expand Down Expand Up @@ -1266,7 +1278,6 @@ def test_erf():
test_forward_arg_min_max()
test_softmax()
test_constantfill()
test_pad()
test_reduce_max()
test_reduce_min()
test_reduce_sum()
Expand Down

0 comments on commit 778c8e8

Please sign in to comment.