Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX Supoort for MXNet reverse op (#19737)
Browse files Browse the repository at this point in the history
* reverse

* Update _op_translations.py
  • Loading branch information
Zha0q1 authored Jan 13, 2021
1 parent aa00f4b commit 13c449e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
41 changes: 41 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,6 +2784,47 @@ def convert_arange(node, **kwargs):
return nodes


@mx_op.register("reverse")
def convert_reverse(node, **kwargs):
"""Map MXNet's reverse operator attributes to ONNX
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', 0))

# Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (10 here)
perm = [i for i in range(10)]
perm[0], perm[axis] = axis, 0

nodes = [
create_tensor([10], name+'_10', kwargs['initializer']),
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([1], name+'_1', kwargs['initializer']),
create_tensor([-1], name+'_m1', kwargs['initializer']),
create_tensor([axis], name+'_axis', kwargs['initializer']),
create_tensor([axis+1], name+'_axis_p1', kwargs['initializer']),
create_tensor([], name+'_void', kwargs['initializer']),
create_const_scalar_node(name+'_m1_s', np.array([-1], dtype='int64'), kwargs),
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
make_node('Reshape', [input_nodes[0], name+'_shape_10_dim'], [name+'_data_10_dim']),
make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
make_node('Slice', [name+'_shape', name+'_axis', name+'_axis_p1'], [name+'_axis_len']),
make_node('Sub', [name+'_axis_len', name+'_1'], [name+'_axis_len_m1']),
make_node('Reshape', [name+'_axis_len_m1', name+'_void'], [name+'_axis_len_m1_s']),
make_node('Range', [name+'_axis_len_m1_s', name+'_m1_s', name+'_m1_s'], [name+'_indices']),
make_node('Gather', [name+'_data_t', name+'_indices'], [name+'_gather']),
make_node('Transpose', [name+'_gather'], [name+'_data_reversed'], perm=perm),
make_node('Reshape', [name+'_data_reversed', name+'_shape'], [name], name=name)
]

return nodes


@mx_op.register('repeat')
def convert_repeat(node, **kwargs):
"""Map MXNet's repeat operator attributes to onnx's Tile operator.
Expand Down
8 changes: 8 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,14 @@ def test_onnx_export_softmax(tmp_path, dtype):
op_export_test('softmax_4', M4, [x, l4], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('axis', [0, 1, 2, 3])
def test_onnx_export_reverse(tmp_path, dtype, axis):
x = mx.nd.arange(0, 120, dtype=dtype).reshape((2, 3, 4, 5))
M = def_model('reverse', axis=axis)
op_export_test('reverse', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3])
@pytest.mark.parametrize('repeats', [2, 1, 3])
Expand Down

0 comments on commit 13c449e

Please sign in to comment.