This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] ONNX add support coverage for Reshape and LSTM #20246

Merged 4 commits on May 6, 2021
@@ -1848,6 +1848,23 @@ def convert_reshape(node, **kwargs):
             make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
         ]
 
+    if targ_shape == [-3, -1] and reverse != 'True':
+        special_case = True
+        create_tensor([0], name+'_0', kwargs['initializer'])
+        create_tensor([1], name+'_1', kwargs['initializer'])
+        create_tensor([2], name+'_2', kwargs['initializer'])
+        create_tensor([-1], name+'_-1', kwargs['initializer'])
+        nodes = [
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Slice', [name+'_shape', name+'_0', name+'_1'], [name+'_1st_dim']),
+            make_node('Slice', [name+'_shape', name+'_1', name+'_2'], [name+'_2nd_dim']),
+            make_node('Mul', [name+'_1st_dim', name+'_2nd_dim'], [name+'_mul']),
+            make_node('Concat', [name+'_mul', name+'_-1'], [name+'_shape_new'], axis=0),
+            make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name),
+        ]
+
     if special_case:
         return nodes
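
For reference, MXNet's reshape accepts the special value -3 (replace two consecutive axes by their product); ONNX Reshape understands -1 (infer) and 0 (copy) but has no -3, so the Shape/Slice/Mul/Concat chain above computes the target shape [d0*d1, -1] at runtime instead. A minimal sketch of the intended semantics, assuming a standard MXNet 1.x install:

    import mxnet as mx

    x = mx.nd.zeros((2, 3, 4, 5))
    y = x.reshape((-3, -1))   # merge axes 0 and 1, infer the rest
    print(y.shape)            # (6, 20), i.e. (2*3, 4*5)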

@@ -4449,17 +4466,14 @@ def convert_RNN(node, **kwargs):
     from onnx import TensorProto
 
     name, input_nodes, attrs = get_inputs(node, kwargs)
+    mode = str(attrs.get('mode'))
 
     bidirectional = str(attrs.get('bidirectional', 'False'))
-    if bidirectional != 'False':
+    if bidirectional != 'False' and mode not in ['lstm']:
         raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
 
     num_layers = int(attrs.get('num_layers', '1'))
 
-    p = float(attrs.get('p', '0'))
-    if p != 0:
-        raise NotImplementedError('Currently RNN onnx export only supports p equals to 0')
-
     use_sequence_length = str(attrs.get('use_sequence_length', 'False'))
     if use_sequence_length != 'False':
         raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False')
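
With the relaxed check, only non-LSTM modes now reject bidirectional export. A hedged sketch of the kind of block this is meant to cover (sizes are illustrative, and per the later hunks only the single-layer case is handled):

    import mxnet as mx

    lstm = mx.gluon.rnn.LSTM(hidden_size=64, num_layers=1, bidirectional=True)
    lstm.initialize()
    out = lstm(mx.nd.zeros((10, 4, 32)))  # layout TNC: (seq_len, batch, input_size)
    print(out.shape)                      # (10, 4, 128): 2 directions * 64 hidden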
@@ -4480,10 +4494,11 @@ def convert_RNN(node, **kwargs):
     nodes = []
     create_tensor([0], name+'_0', kwargs['initializer'])
 
-    mode = str(attrs.get('mode'))
     if mode == 'lstm':
         initial_c = input_nodes[3]
         if num_layers == 2:
+            if bidirectional != 'False':
+                raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
             create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
             create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
             create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
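
The slice offsets built in the next hunk assume MXNet's fused parameter vector for a one-layer LSTM: the gate weights W, then the recurrent weights R, then 8*state_size bias terms; in the bidirectional branch the order is W_fwd, R_fwd, W_bwd, R_bwd, followed by both bias blocks. A small sketch of the unidirectional offset arithmetic (values are illustrative; the comments name the tensors in the diff):

    state_size, input_size = 64, 32
    w_end = 4 * state_size * input_size          # end of W, name+'_mul0'
    r_end = w_end + 4 * state_size * state_size  # end of R, name+'_add0'
    b_end = r_end + 8 * state_size               # end of B, name+'_add1'
    print(w_end, r_end, b_end)                   # 8192 24576 25088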
@@ -4555,45 +4570,124 @@ def convert_RNN(node, **kwargs):
                 make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
             ]
         elif num_layers == 1:
-            create_tensor([1], name+'_1', kwargs['initializer'])
-            create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
-            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
-            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
-            create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
-            create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
-
-            nodes += [
-                make_node('Shape', [data], [name+'_data_shape']),
-                make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
-                # get W
-                make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
-                make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
-                make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
-                make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
-                make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
-                make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
-                # get R
-                make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
-                make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
-                make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
-                make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
-                make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
-                # get B
-                make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
-                make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
-                make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
-                                                    name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
-                make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
-                                     name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
-                make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
-                # get seq_len
-                make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
-                make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
-                # compute LSTM
-                make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
-                          [name+'0_', name+'1', name+'2'], hidden_size=state_size),
-                make_node('Squeeze', [name+'0_'], [name], axes=[1]),
-            ]
+            if bidirectional == 'False':
+                create_tensor([1], name+'_1', kwargs['initializer'])
+                create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+                create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+                create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+                create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+                create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
+                nodes += [
+                    make_node('Shape', [data], [name+'_data_shape']),
+                    make_node('Split', [name+'_data_shape'],
+                              [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                    # get W
+                    make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+                    make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                    make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+                    make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+                    make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'],
+                              [name+'_W_shape'], axis=0),
+                    make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+                    # get R
+                    make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+                    make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                    make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+                    make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+                    make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+                    # get B
+                    make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+                    make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+                    make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                                        name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+                    make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                                         name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+                    make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+                    # get seq_len
+                    make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                    make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                    # compute LSTM
+                    make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                              [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+                    make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+                ]
+            else:
+                create_tensor([1], name+'_1', kwargs['initializer'])
+                create_tensor([-1], name+'_-1', kwargs['initializer'])
+                create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+                create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+                create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+                create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+                create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+
+                nodes += [
+                    make_node('Shape', [data], [name+'_data_shape']),
+                    make_node('Split', [name+'_data_shape'],
+                              [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+                    # get W_fwd
+                    make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+                    make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+                    make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+                    make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'],
+                              [name+'_W_'], axis=0),
+                    make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'],
+                              [name+'_W_shape'], axis=0),
+                    make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']),
+                    # get R_fwd
+                    make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+                    make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+                    make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+                    make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+                    make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R_fwd']),
+                    # get W_bwd
+                    make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']),
+                    make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']),
+                    make_node('Split', [name+'_W_1d_bwd'],
+                              [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']),
+                    make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'],
+                              [name+'_W_bwd_'], axis=0),
+                    make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']),
+                    # get R_bwd
+                    make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']),
+                    make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']),
+                    make_node('Split', [name+'_R_1d_bwd'],
+                              [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']),
+                    make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'],
+                              [name+'_R_bwd_'], axis=0),
+                    make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']),
+                    # get B_fwd
+                    make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']),
+                    make_node('Slice', [param, name+'_add2', name+'_add3'], [name+'_B_1d']),
+                    make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+                                                        name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+                    make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+                                         name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+                    make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']),
+                    # get B_bwd
+                    make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']),
+                    make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']),
+                    make_node('Split', [name+'_B_1d_bwd'],
+                              [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd',
+                               name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']),
+                    make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd',
+                                         name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'],
+                              [name+'_B_bwd_'], axis=0),
+                    make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']),
+                    # get seq_len
+                    make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+                    make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+                    # compute LSTM
+                    make_node('Concat', [name+'_W_fwd', name+'_W_bwd'], [name+'_W'], axis=0),
+                    make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0),
+                    make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0),
+                    make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+                              [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'),
+                    make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]),
+                    make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'],
+                              [name+'_shape_out'], axis=0),
+                    make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]),
+                ]
         else:
             raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
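
The recurring [0, 3, 1, 2] concatenation implements a gate reorder: MXNet's fused parameters appear to follow the cuDNN gate layout [i, f, g, o], while ONNX LSTM expects [i, o, f, c], where ONNX's c is MXNet's g. A numpy sketch of the same permutation (shapes are illustrative):

    import numpy as np

    state_size, input_size = 2, 3
    W_1d = np.arange(4 * state_size * input_size)    # flattened input weights
    W0, W1, W2, W3 = np.split(W_1d, 4)               # chunks i, f, g, o
    W_ = np.concatenate([W0, W3, W1, W2])            # reordered to i, o, f, c
    W = W_.reshape(1, 4 * state_size, input_size)    # ONNX W layout
    print(W.shape)                                   # (1, 8, 3)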

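In the bidirectional branch, the trailing Transpose and Reshape adapt the output layout: ONNX LSTM emits Y as (seq_length, num_directions, batch_size, hidden_size), while MXNet's RNN layer returns (seq_length, batch_size, num_directions*hidden_size). A numpy sketch of that fix-up:

    import numpy as np

    seq_len, num_dir, batch, hidden = 10, 2, 4, 64
    Y = np.zeros((seq_len, num_dir, batch, hidden))            # ONNX LSTM Y output
    out = Y.transpose(0, 2, 1, 3).reshape(seq_len, batch, -1)  # MXNet layout
    print(out.shape)                                           # (10, 4, 128)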