
ONNX export: GlobalLpPool, LpPool
vandanavk committed Dec 12, 2018
1 parent 3afe2fe commit a10ac50
Showing 5 changed files with 92 additions and 25 deletions.
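For context, here is a minimal usage sketch (not part of the commit) of the export path these changes enable, assuming the MXNet 1.x contrib ONNX API; the layer names, shapes, and output file path below are illustrative only.

    import numpy as np
    import mxnet as mx

    # Hypothetical example: a graph with an lp-pooling layer.
    # p_value on the MXNet side maps to the ONNX LpPool/GlobalLpPool attribute `p`.
    data = mx.sym.Variable('data')
    sym = mx.sym.Pooling(data, pool_type='lp', kernel=(2, 2), stride=(2, 2), p_value=2)

    # Export to ONNX; the pooling-only graph has no learnable parameters,
    # so an empty params dict is passed.
    onnx_file = mx.contrib.onnx.export_model(sym, {}, [(1, 3, 32, 32)],
                                             np.float32, 'lp_pool.onnx')

Before this change, convert_pooling had no entry for pool_type 'lp', so an export like the one above would fail.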
57 changes: 40 additions & 17 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -575,6 +575,7 @@ def convert_pooling(node, **kwargs):
     pool_type = attrs["pool_type"]
     stride = eval(attrs["stride"]) if attrs.get("stride") else None
     global_pool = get_boolean_attribute_value(attrs, "global_pool")
+    p_value = int(attrs.get('p_value', '2'))
 
     pooling_convention = attrs.get('pooling_convention', 'valid')
 
@@ -587,26 +588,48 @@
 
     pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
     pad_dims = pad_dims + pad_dims
-    pool_types = {"max": "MaxPool", "avg": "AveragePool"}
-    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
+    pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
+    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
+                         "lp": "GlobalLpPool"}
 
     if global_pool:
-        node = onnx.helper.make_node(
-            global_pool_types[pool_type],
-            input_nodes,  # input
-            [name],
-            name=name
-        )
+        if pool_type == 'lp':
+            node = onnx.helper.make_node(
+                global_pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                p=p_value,
+                name=name
+            )
+        else:
+            node = onnx.helper.make_node(
+                global_pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                name=name
+            )
     else:
-        node = onnx.helper.make_node(
-            pool_types[pool_type],
-            input_nodes,  # input
-            [name],
-            kernel_shape=kernel,
-            pads=pad_dims,
-            strides=stride,
-            name=name
-        )
+        if pool_type == 'lp':
+            node = onnx.helper.make_node(
+                pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                p=p_value,
+                kernel_shape=kernel,
+                pads=pad_dims,
+                strides=stride,
+                name=name
+            )
+        else:
+            node = onnx.helper.make_node(
+                pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                kernel_shape=kernel,
+                pads=pad_dims,
+                strides=stride,
+                name=name
+            )
 
     return [node]
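For illustration only (not part of the diff), this is roughly the ONNX node the new lp branch above builds; the tensor and node names are hypothetical, and p=2 corresponds to MXNet's default p_value.

    from onnx import helper

    lp_node = helper.make_node(
        'LpPool',
        ['pool0_data'],        # input tensor name (hypothetical)
        ['pool0_output'],      # output tensor name (hypothetical)
        p=2,                   # MXNet p_value -> ONNX p
        kernel_shape=[4, 5],
        pads=[0, 0, 0, 0],     # pad_dims concatenated with itself, as in the converter
        strides=[1, 1],
        name='pool0'
    )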

8 changes: 5 additions & 3 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -395,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj):
                                                         'kernel': (1, 1),
                                                         'pool_type': 'lp',
                                                         'p_value': p_value})
+    new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
     return 'Pooling', new_attrs, inputs
 
 def linalg_gemm(attrs, inputs, proto_obj):
@@ -684,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'kernel_shape': 'kernel',
                                                         'strides': 'stride',
-                                                        'pads': 'pad',
-                                                        'p_value': p_value
+                                                        'pads': 'pad'
                                                        })
+    new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs,
-                                                        {'pooling_convention': 'valid'
+                                                        {'pooling_convention': 'valid',
+                                                         'p_value': p_value
                                                         })
     new_op = translation_utils._fix_pooling('lp', inputs, new_attrs)
     return new_op, new_attrs, inputs
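A rough sketch (not from the commit) of the attribute translation lp_pooling now performs on import, written with plain dicts instead of the translation_utils helpers; the values are examples only.

    # ONNX LpPool attributes as they arrive from the model...
    onnx_attrs = {'kernel_shape': (4, 5), 'strides': (1, 1), 'pads': (0, 0, 0, 0), 'p': 2}

    # ...and the MXNet Pooling attributes they are mapped to: `p` is read out,
    # dropped from the attribute dict, and re-added as 'p_value'.
    mx_attrs = {
        'kernel': onnx_attrs['kernel_shape'],
        'stride': onnx_attrs['strides'],
        'pad': onnx_attrs['pads'],
        'pooling_convention': 'valid',
        'p_value': onnx_attrs['p'],
    }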
6 changes: 5 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr):
     stride = new_attr.get('stride')
     kernel = new_attr.get('kernel')
     padding = new_attr.get('pad')
+    p_value = new_attr.get('p_value')
 
     # Adding default stride.
     if stride is None:
@@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr):
            new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width)
 
     # Apply pooling without pads.
-    new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
+    if pool_type == 'lp':
+        new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel, p_value=p_value)
+    else:
+        new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
     return new_pooling_op
 
 def _fix_bias(op_name, attrs, num_inputs):
44 changes: 42 additions & 2 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -98,8 +98,11 @@ def forward_pass(sym, arg, aux, data_names, input_data):
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
     mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
-    mod.set_params(arg_params=arg, aux_params=aux,
-                   allow_missing=True, allow_extra=True)
+    if not arg and not aux:
+        mod.init_params()
+    else:
+        mod.set_params(arg_params=arg, aux_params=aux,
+                       allow_missing=True, allow_extra=True)
     # run inference
     batch = namedtuple('Batch', ['data'])
     mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
@@ -345,6 +348,43 @@ def test_ops(op_name, inputs, input_tensors, numpy_op):
                                 np.logical_not(input_data[0]).astype(np.float32))
 
 
+@with_seed()
+def testLpPooling():
+    def test_pooling(opname, data, attrs, p):
+        input1 = np.random.rand(*data).astype("float32")
+        inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=data)]
+        sym = mx.sym.Pooling(mx.sym.Variable('input1'), pool_type='lp', p_value=p, **attrs)
+        lppool_output = forward_pass(sym, None, None, ['input1'], input1)
+
+        lppool_op_tensor = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(lppool_output))]
+
+        if attrs.get('global_pool', False):
+            lppool_node = [helper.make_node(opname, ["input1"], ["output"], p=p)]
+        else:
+            lppool_node = [helper.make_node(opname, ["input1"], ["output"], **attrs, p=p)]
+
+        lppool_graph = helper.make_graph(lppool_node,
+                                         opname + "_test",
+                                         inputs,
+                                         lppool_op_tensor)
+
+        lppool_model = helper.make_model(lppool_graph)
+
+        bkd_rep = backend.prepare(lppool_model)
+        output = bkd_rep.run([input1])
+
+        npt.assert_almost_equal(output[0], lppool_output)
+
+    ip = (2, 3, 20, 20)
+    kernel = (4, 5)
+    pad = (0, 0)
+    stride = (1, 1)
+
+    for p_value in range(1, 3):
+        test_pooling('LpPool', ip, {'kernel': kernel, 'stride': stride, 'pad': pad}, p=p_value)
+        test_pooling('GlobalLpPool', ip, {'kernel': kernel, 'stride': stride, 'pad': pad, 'global_pool': True}, p=p_value)
+
+
 def _assert_sym_equal(lhs, rhs):
     assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
     assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical
2 changes: 0 additions & 2 deletions tests/python-pytest/onnx/import/test_cases.py
@@ -45,7 +45,6 @@
     'test_transpose',
     'test_globalmaxpool',
     'test_globalaveragepool',
-    'test_global_lppooling',
     'test_slice_cpu',
     'test_slice_neg',
     'test_reciprocal',
@@ -77,7 +76,6 @@
     'test_averagepool_2d_precomputed_strides',
     'test_averagepool_2d_strides',
     'test_averagepool_3d',
-    'test_LpPool_',
     'test_cast',
     'test_instancenorm',
     #pytorch operator tests
