Skip to content

Commit

Permalink
[ONNX]LpPool Support added (apache#5696)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 9, 2020
1 parent 92f4fd1 commit 22761ab
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,57 @@ class MaxPool(Pool):
"""
name = 'max_pool'

class LpPool(OnnxOpConverter):
""" A helper class for lppool op converters.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
input_shape = infer_shape(inputs[0])
dtype = infer_type(inputs[0]).checked_type.dtype

if 'auto_pad' in attr:
attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = attr['strides'][axis]
kernel = attr['kernel_shape'][axis]
pad = get_pad_pair(axis_shape, kernel, stride)
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr['pads'] = pad_tuple
elif attr['auto_pad'] == 'VALID':
attr['pads'] = 0
elif attr['auto_pad'] == 'NOTSET':
pass
else:
msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], "LpPool"))
attr.pop("auto_pad")

if 'storage_order' in attr:
attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
dims=(len(input_shape) - 2))
else:
attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2))

p = _expr.const(attr['p'], dtype)
reci_p = _expr.const(1.0 / attr['p'], dtype)
inputs[0] = _op.power(inputs[0], p)

out = AttrCvt(op_name=dimension_picker("avg_pool"),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', 0)
},
extras={'count_include_pad': True},
ignores=['p'],
custom_check=dimension_constraint())(inputs, attr, params)
kernels = attr['kernel_shape']
out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype))
return _op.power(out, reci_p)


class Mul(Elemwise):
""" Operator converter for Multiply.
Expand Down Expand Up @@ -1660,6 +1711,7 @@ def _get_convert_map(opset):

# defs/nn
'AveragePool': AveragePool.get_converter(opset),
'LpPool': LpPool.get_converter(opset),
'MaxPool': MaxPool.get_converter(opset),
'Conv': Conv.get_converter(opset),
'ConvTranspose': ConvTranspose.get_converter(opset),
Expand Down
71 changes: 71 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,76 @@ def test_pooling():
auto_pad='SAME_UPPER')


def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"):
x_np = np.random.uniform(size=x_shape).astype('float32')

if pads is None:
pool_node = helper.make_node("LpPool",
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
p = p,
auto_pad=auto_pad,
strides=strides)
else:
pool_node = helper.make_node("LpPool",
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
p = p,
pads=pads,
strides=strides)

graph = helper.make_graph([pool_node],
"lppool_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(x_shape))],
outputs=[helper.make_tensor_value_info("y",
TensorProto.FLOAT, list(out_shape))])

model = helper.make_model(graph, producer_name='lppool_test')

for target, ctx in ctx_list():
onnx_out = get_onnxruntime_output(model, x_np, 'float32')
tvm_out = get_tvm_output(
model, [x_np], target, ctx, out_shape)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_lppool():
# Pool1D
verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1],
out_shape=[1, 1, 32])

# Pool2D
verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[1, 1],
pads=[1, 1, 1, 1], out_shape=[1, 1, 32, 32])

# Pool1D with stride
verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1],
out_shape=[1, 1, 16])

# Pool2D with stride
verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2],
pads=[1, 1, 1, 1], out_shape=[1, 1, 16, 16])

# Pool1D with stride and autopadding
verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=None,
out_shape=[1, 1, 16], auto_pad='SAME_UPPER')

# Pool2D with stride and autopadding
verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2],
pads=None, out_shape=[1, 1, 16, 16], auto_pad='SAME_UPPER')

# Pool3D with stride
verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2],
pads=[1, 1, 1, 1, 1, 1], out_shape=[1, 1, 16, 16, 16])

# Pool3D with stride and autopadding
verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2],
pads=None, out_shape=[1, 1, 16, 16, 16], auto_pad='SAME_UPPER')


def verify_lstm(seq_length,
batch_size,
input_size,
Expand Down Expand Up @@ -2722,6 +2792,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_convtranspose()
test_unsqueeze_constant()
test_pooling()
test_lppool()
test_lstm()
test_resize()
test_nonzero()
Expand Down

0 comments on commit 22761ab

Please sign in to comment.