[v1.x] ONNX export support for RNN and sum_axis (#20226)
* export support for RNN

* add sum_axis

* fix sanity

* fix sanity

* fix sanity

* change register sum_axis

Co-authored-by: Wei Chu <weichu@amazon.com>
waytrue17 and Wei Chu authored Apr 30, 2021
1 parent 4056c07 commit 2127c3e
Showing 3 changed files with 212 additions and 23 deletions.
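For context, a minimal export call exercising the new converters might look like the sketch below. This is not part of the commit: the entry point shown is the MXNet 1.x contrib exporter (keyword names vary across 1.x releases), and the checkpoint file names are hypothetical placeholders.

# Sketch only, not part of this commit: exporting a checkpoint that
# contains RNN or sum_axis ops, which previously had no registered
# converter. File names are hypothetical.
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

onnx_file = onnx_mxnet.export_model(
    sym='rnn_model-symbol.json',
    params='rnn_model-0000.params',
    input_shape=[(10, 2, 8)],   # (seq_length, batch_size, input_size)
    input_type=np.float32,
    onnx_file_path='rnn_model.onnx',
)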
@@ -2137,7 +2137,9 @@ def convert_square(node, **kwargs):
)
return [tensor_node, node]

# sum_axis is equivalent to sum in MXNet
@mx_op.register("sum")
@mx_op.register("sum_axis")
def convert_sum(node, **kwargs):
"""Map MXNet's sum operator attributes to onnx's ReduceSum operator
and return the created node.
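The stacked @mx_op.register(...) calls above map both op names onto the one converter, since each call returns the function unchanged. A minimal sketch of the pattern follows; the registry dict and register helper are hypothetical stand-ins, not the actual mx2onnx internals.

# Hypothetical minimal op-converter registry illustrating stacked
# register decorators: each call adds the same function under a new key.
registry = {}

def register(op_name):
    def wrapper(func):
        registry[op_name] = func
        return func  # return the function unchanged so decorators stack
    return wrapper

@register("sum")
@register("sum_axis")
def convert_sum(node, **kwargs):
    """Both 'sum' and 'sum_axis' nodes dispatch here."""
    return ["ReduceSum-node-for-" + node]

assert registry["sum"] is registry["sum_axis"]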
@@ -4476,12 +4478,12 @@ def convert_RNN(node, **kwargs):
initial_h = input_nodes[2]

nodes = []
create_tensor([0], name+'_0', kwargs['initializer'])

mode = str(attrs.get('mode'))
if mode == 'lstm':
initial_c = input_nodes[3]
if num_layers == 2:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -4553,7 +4555,6 @@ def convert_RNN(node, **kwargs):
make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
]
elif num_layers == 1:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
@@ -4598,7 +4599,6 @@ def convert_RNN(node, **kwargs):

elif mode == 'gru':
if num_layers == 2:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -4669,7 +4669,7 @@ def convert_RNN(node, **kwargs):
]

elif num_layers == 1:
create_tensor([0], name+'_0', kwargs['initializer'])

create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
@@ -4712,6 +4712,100 @@ def convert_RNN(node, **kwargs):
else:
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')

elif mode in ['rnn_tanh', 'rnn_relu']:
activations = ['Tanh']
if mode == 'rnn_relu':
activations = ['Relu']
if num_layers == 2:

create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
create_tensor([1, state_size, state_size], name+'_WR_shape', kwargs['initializer'])
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
create_tensor([4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])

nodes += [
make_node('Shape', [data], [name+'_data_shape']),
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),

# Layer 0
# get W
make_node('Slice', [param, name+'_0', name+'_state_size^2'], [name+'_W0_1d']),
make_node('Reshape', [name+'_W0_1d', name+'_WR_shape'], [name+'_W0']),
# get R
make_node('Add', [name+'_state_size^2', name+'_state_size^2'], [name+'_R0_offset']),
make_node('Slice', [param, name+'_state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
make_node('Reshape', [name+'_R0_1d', name+'_WR_shape'], [name+'_R0']),
# get B
make_node('Add', [name+'_WR_offset', name+'_2*state_size'], [name+'_B0_offset']),
make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
make_node('Reshape', [name+'_B0_1d', name+'_B_shape'], [name+'_B0']),
# get initial states
make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
# get seq_len
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
# Layer 0 RNN
make_node('RNN', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len', name+'_initial_h0'],
[name+'_rnn0_out_', name+'_rnn0_h'], hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'_rnn0_out_'], [name+'_rnn0_out'], axes=[1]),

# Layer 1
# get W
make_node('Add', [name+'_R0_offset', name+'_state_size^2'], [name+'_W1_offset']),
make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
make_node('Reshape', [name+'_W1_1d', name+'_WR_shape'], [name+'_W1']),
# get R
make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
make_node('Reshape', [name+'_R1_1d', name+'_WR_shape'], [name+'_R1']),
# get B
make_node('Add', [name+'_B0_offset', name+'_2*state_size'], [name+'_B1_offset']),
make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
make_node('Reshape', [name+'_B1_1d', name+'_B_shape'], [name+'_B1']),
# Layer 1 RNN
make_node('RNN', [name+'_rnn0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
name+'_initial_h1'], [name+'_rnn1_out_', name+'_rnn1_h'],
hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'_rnn1_out_'], [name], axes=[1]),
make_node('Concat', [name+'_rnn0_h', name+'_rnn1_h'], [name+'1'], axis=0)
]

elif num_layers == 1:

create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([state_size], name+'_state_size', kwargs['initializer'])
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
create_tensor([1, state_size, state_size], name+'_R_shape', kwargs['initializer'])
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])

nodes += [
make_node('Shape', [data], [name+'_data_shape']),
make_node('Split', [name+'_data_shape'],
[name+'_seq_length', name+'_batch_size', name+'_input_size'], name='split0'),
# get W
make_node('Mul', [name+'_state_size', name+'_input_size'], [name+'_mul0']),
make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
make_node('Concat', [name+'_1', name+'_state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
make_node('Reshape', [name+'_W_1d', name+'_W_shape'], [name+'_W']),
# get R
make_node('Add', [name+'_mul0', name+'_state_size^2'], [name+'_add0']),
make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
make_node('Reshape', [name+'_R_1d', name+'_R_shape'], [name+'_R']),
# get B
make_node('Add', [name+'_add0', name+'_2*state_size'], [name+'_add1']),
make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
make_node('Reshape', [name+'_B_1d', name+'_B_shape'], [name+'_B']),
# get seq_len
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
# compute RNN
make_node('RNN', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
[name+'0_', name+'1'], hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'0_'], [name], axes=[1]),
]
else:
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
else:
raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
return nodes
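The Slice offsets in the num_layers == 1 branch above unpack MXNet's fused RNN parameter vector: W (state_size*input_size values), then R (state_size*state_size), then the two bias halves (2*state_size total). A worked sketch of that layout in plain NumPy, with illustrative sizes, assuming the vanilla-RNN layout implied by the slicing above:

# Sketch: how the flat MXNet RNN parameter vector is carved up by the
# Slice offsets above for num_layers == 1. Sizes are illustrative.
import numpy as np

input_size, state_size = 8, 16
w_len = state_size * input_size   # end offset name+'_mul0': W (input -> hidden)
r_len = state_size * state_size   # name+'_state_size^2':   R (hidden -> hidden)
b_len = 2 * state_size            # name+'_2*state_size':   B (Wb then Rb)

param = np.arange(w_len + r_len + b_len, dtype=np.float32)

W = param[:w_len].reshape(1, state_size, input_size)               # name+'_W'
R = param[w_len:w_len + r_len].reshape(1, state_size, state_size)  # name+'_R'
B = param[w_len + r_len:].reshape(1, 2 * state_size)               # name+'_B'
assert W.size + R.size + B.size == param.size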
@@ -1047,11 +1047,12 @@ def convert_RNN(node, **kwargs):
nodes = []

mode = str(attrs.get('mode'))
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])

if mode == 'lstm':
initial_c = input_nodes[3]
if num_layers == 2:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -1123,7 +1124,6 @@ def convert_RNN(node, **kwargs):
make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
]
elif num_layers == 1:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
@@ -1167,7 +1167,6 @@ def convert_RNN(node, **kwargs):

elif mode == 'gru':
if num_layers == 2:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -1238,7 +1237,6 @@ def convert_RNN(node, **kwargs):
]

elif num_layers == 1:
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
@@ -1272,14 +1270,106 @@ def convert_RNN(node, **kwargs):
# get seq_len
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
# compute LSTM
# compute GRU
make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
[name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1),
make_node('Squeeze', [name+'0_', name+'_1'], [name]),
]
else:
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')

elif mode in ['rnn_tanh', 'rnn_relu']:
activations = ['Tanh']
if mode == 'rnn_relu':
activations = ['Relu']
if num_layers == 2:
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
create_tensor([1, state_size, state_size], name+'_WR_shape', kwargs['initializer'])
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
create_tensor([4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])

nodes += [
make_node('Shape', [data], [name+'_data_shape']),
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),

# Layer 0
# get W
make_node('Slice', [param, name+'_0', name+'_state_size^2'], [name+'_W0_1d']),
make_node('Reshape', [name+'_W0_1d', name+'_WR_shape'], [name+'_W0']),
# get R
make_node('Add', [name+'_state_size^2', name+'_state_size^2'], [name+'_R0_offset']),
make_node('Slice', [param, name+'_state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
make_node('Reshape', [name+'_R0_1d', name+'_WR_shape'], [name+'_R0']),
# get B
make_node('Add', [name+'_WR_offset', name+'_2*state_size'], [name+'_B0_offset']),
make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
make_node('Reshape', [name+'_B0_1d', name+'_B_shape'], [name+'_B0']),
# get initial states
make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
# get seq_len
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
# Layer 0 RNN
make_node('RNN', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
name+'_initial_h0'], [name+'_rnn0_out_', name+'_rnn0_h'],
hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'_rnn0_out_', name+'_1'], [name+'_rnn0_out']),

# Layer 1
# get W
make_node('Add', [name+'_R0_offset', name+'_state_size^2'], [name+'_W1_offset']),
make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
make_node('Reshape', [name+'_W1_1d', name+'_WR_shape'], [name+'_W1']),
# get R
make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
make_node('Reshape', [name+'_R1_1d', name+'_WR_shape'], [name+'_R1']),
# get B
make_node('Add', [name+'_B0_offset', name+'_2*state_size'], [name+'_B1_offset']),
make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
make_node('Reshape', [name+'_B1_1d', name+'_B_shape'], [name+'_B1']),
# Layer 1 RNN
make_node('RNN', [name+'_rnn0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
name+'_initial_h1'], [name+'_rnn1_out_', name+'_rnn1_h'],
hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'_rnn1_out_', name+'_1'], [name]),
make_node('Concat', [name+'_rnn0_h', name+'_rnn1_h'], [name+'1'], axis=0)
]

elif num_layers == 1:
create_tensor([state_size], name+'_state_size', kwargs['initializer'])
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
create_tensor([1, state_size, state_size], name+'_R_shape', kwargs['initializer'])
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])

nodes += [
make_node('Shape', [data], [name+'_data_shape']),
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size',
name+'_input_size'], name='split0'),
# get W
make_node('Mul', [name+'_state_size', name+'_input_size'], [name+'_mul0']),
make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
make_node('Concat', [name+'_1', name+'_state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
make_node('Reshape', [name+'_W_1d', name+'_W_shape'], [name+'_W']),
# get R
make_node('Add', [name+'_mul0', name+'_state_size^2'], [name+'_add0']),
make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
make_node('Reshape', [name+'_R_1d', name+'_R_shape'], [name+'_R']),
# get B
make_node('Add', [name+'_add0', name+'_2*state_size'], [name+'_add1']),
make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
make_node('Reshape', [name+'_B_1d', name+'_B_shape'], [name+'_B']),
# get seq_len
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
# compute RNN
make_node('RNN', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
[name+'0_', name+'1'], hidden_size=state_size, activations=activations),
make_node('Squeeze', [name+'0_', name+'_1'], [name]),
]
else:
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
else:
raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
return nodes
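One difference between the two files in this commit: the first file's converter emits Squeeze with an axes attribute (axes=[1]), while this file passes the axes as a second input (the [1]-valued initializer name+'_1'), which is the opset-13 form of Squeeze. A short sketch of the two forms using onnx.helper:

# Sketch: the two Squeeze forms seen in the two files above.
from onnx import helper

# Pre-opset-13 form (first file): axes is a node attribute.
squeeze_attr = helper.make_node('Squeeze', ['x'], ['y'], axes=[1])

# Opset-13 form (this file): axes arrives as a second input tensor,
# analogous to the [1]-valued initializer created as name+'_1' above.
squeeze_input = helper.make_node('Squeeze', ['x', 'axes'], ['y'])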