From 309f0264b0b2258a05330bac0fa1eaff615139a5 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 4 May 2021 21:57:06 -0700 Subject: [PATCH 1/4] lstm and reshape --- .../_op_translations_opset12.py | 173 +++++++++++++----- .../_op_translations_opset13.py | 153 +++++++++++----- tests/python-pytest/onnx/test_operators.py | 27 ++- 3 files changed, 255 insertions(+), 98 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 2f38faa05f2b..0c26aa2b3003 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -1848,6 +1848,23 @@ def convert_reshape(node, **kwargs): make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] + if targ_shape == [-3, -1] and reverse != 'True': + special_case = True + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([2], name+'_2', kwargs['initializer']) + create_tensor([-1], name+'_-1', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Slice', [name+'_shape', name+'_0', + name+'_1'], [name+'_1st_dim']), + make_node('Slice', [name+'_shape', name+'_1', + name+'_2'], [name+'_2nd_dim']), + make_node('Mul', [name+'_1st_dim', name+'_2nd_dim'], [name+'_mul']), + make_node('Concat', [name+'_mul', name+'_-1'], [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name), + ] + if special_case: return nodes @@ -4449,17 +4466,14 @@ def convert_RNN(node, **kwargs): from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) + mode = str(attrs.get('mode')) bidirectional = str(attrs.get('bidirectional', 'False')) - if bidirectional != 'False': + if bidirectional != 'False' and mode not in ['lstm']: raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') num_layers = int(attrs.get('num_layers', '1')) - p = float(attrs.get('p', '0')) - if p != 0: - raise NotImplementedError('Currently RNN onnx export only supports p equals to 0') - use_sequence_length = str(attrs.get('use_sequence_length', 'False')) if use_sequence_length != 'False': raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False') @@ -4480,10 +4494,11 @@ def convert_RNN(node, **kwargs): nodes = [] create_tensor([0], name+'_0', kwargs['initializer']) - mode = str(attrs.get('mode')) if mode == 'lstm': initial_c = input_nodes[3] if num_layers == 2: + if bidirectional != 'False': + raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer']) @@ -4555,45 +4570,113 @@ def convert_RNN(node, **kwargs): make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0), ] elif num_layers == 1: - create_tensor([1], name+'_1', kwargs['initializer']) - create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) - create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) - create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) - create_tensor([1, 4*state_size, 
state_size], name+'_R_shape', kwargs['initializer']) - create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) - - nodes += [ - make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), - # get W - make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), - make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), - make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), - make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), - make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), - # get R - make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), - make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), - make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), - make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), - make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), - # get B - make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), - make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), - make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', - name+'_B4', name+'_B5', name+'_B6', name+'_B7']), - make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), - make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), - # get seq_len - make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), - make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), - # compute LSTM - make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size), - make_node('Squeeze', [name+'0_'], [name], axes=[1]), - ] + if bidirectional == 'False': + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), + # get R + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], 
[name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), + # get B + make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size), + make_node('Squeeze', [name+'0_'], [name], axes=[1]), + ] + else: + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([-1], name+'_-1', kwargs['initializer']) + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W_fwd + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']), + # get R_fwd + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R_fwd']), + # get W_bwd + make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']), + make_node('Split', [name+'_W_1d_bwd'], [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), + make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], [name+'_W_bwd_'], axis=0), + # make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']), + # get R_bwd + make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']), + make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']), + make_node('Split', 
[name+'_R_1d_bwd'], [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), + make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], [name+'_R_bwd_'], axis=0), + make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']), + # get B_fwd + make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']), + make_node('Slice', [param, name+'_add2', name+'_add3'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']), + # get B_bwd + make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']), + make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']), + make_node('Split', [name+'_B_1d_bwd'], [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd', + name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), + make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd', + name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], [name+'_B_bwd_'], axis=0), + make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('Concat', [name+'_W_fwd', name+'_W_bwd'], [name+'_W'], axis=0), + make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0), + make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0), + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), + make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]), + make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], [name+'_shape_out'], axis=0), + make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]), + ] else: raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2') diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 4ac6dfdff21c..4095f6036d5d 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1015,16 +1015,13 @@ def convert_RNN(node, **kwargs): name, input_nodes, attrs = get_inputs(node, kwargs) + mode = str(attrs.get('mode')) bidirectional = str(attrs.get('bidirectional', 'False')) - if bidirectional != 'False': + if bidirectional != 'False' and mode not in ['lstm']: raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') num_layers = int(attrs.get('num_layers', '1')) - p = float(attrs.get('p', '0')) - if p != 0: - raise NotImplementedError('Currently RNN onnx export only supports p equals to 0') - use_sequence_length = str(attrs.get('use_sequence_length', 'False')) if use_sequence_length != 'False': raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False') @@ -1044,13 +1041,14 @@ def convert_RNN(node, **kwargs): nodes = [] - mode 
= str(attrs.get('mode')) create_tensor([0], name+'_0', kwargs['initializer']) create_tensor([1], name+'_1', kwargs['initializer']) if mode == 'lstm': initial_c = input_nodes[3] if num_layers == 2: + if bidirectional != 'False': + raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer']) @@ -1122,44 +1120,111 @@ def convert_RNN(node, **kwargs): make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0), ] elif num_layers == 1: - create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) - create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) - create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) - create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) - create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) - - nodes += [ - make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), - # get W - make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), - make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), - make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), - make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), - make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), - # get R - make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), - make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), - make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), - make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), - make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), - # get B - make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), - make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), - make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', - name+'_B4', name+'_B5', name+'_B6', name+'_B7']), - make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), - make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), - # get seq_len - make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), - make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), - # compute LSTM - make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size), - make_node('Squeeze', [name+'0_', name+'_1'], [name]), - ] + if bidirectional == 'False': + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 
8*state_size], name+'_B_shape', kwargs['initializer']) + + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), + # get R + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), + # get B + make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size), + make_node('Squeeze', [name+'0_', name+'_1'], [name]), + ] + else: + create_tensor([-1], name+'_-1', kwargs['initializer']) + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W_fwd + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']), + # get R_fwd + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', 
name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R_fwd']), + # get W_bwd + make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']), + make_node('Split', [name+'_W_1d_bwd'], [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), + make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], [name+'_W_bwd_'], axis=0), + # make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']), + # get R_bwd + make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']), + make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']), + make_node('Split', [name+'_R_1d_bwd'], [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), + make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], [name+'_R_bwd_'], axis=0), + make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']), + # get B_fwd + make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']), + make_node('Slice', [param, name+'_add2', name+'_add3'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']), + # get B_bwd + make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']), + make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']), + make_node('Split', [name+'_B_1d_bwd'], [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd', + name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), + make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd', + name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], [name+'_B_bwd_'], axis=0), + make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('Concat', [name+'_W_fwd', name+'_W_bwd'], [name+'_W'], axis=0), + make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0), + make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0), + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), + make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]), + make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], [name+'_shape_out'], axis=0), + make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]), + ] else: raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2') diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 11d14b7596c6..8d7ce857bbf0 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -250,6 +250,8 @@ def test_onnx_export_reshape(tmp_path, dtype): 
op_export_test('reshape_2', M2, [x], tmp_path)
M3 = def_model('reshape', shape=(5, 1, 1, 1, 1, 0, -1, 0), reverse=True)
op_export_test('reshape_3', M3, [x], tmp_path)
+ M4 = def_model('reshape', shape=(-3, -1))
+ op_export_test('reshape_4', M4, [x], tmp_path)
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
@@ -1233,8 +1235,6 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
M1 = def_model('SequenceReverse', use_sequence_length=True)
op_export_test('SequenceReverse1', M1, [x, seq_len], tmp_path)
-
-# onnx LSTM from opset 11 does not support float64
@pytest.mark.parametrize('mode', ['lstm', 'gru', 'rnn_tanh', 'rnn_relu'])
@pytest.mark.parametrize('dtype', ['float32'])
@pytest.mark.parametrize('state_size', [16, 32, 64])
@@ -1242,25 +1242,34 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
@pytest.mark.parametrize('num_layers', [1, 2])
@pytest.mark.parametrize('batch_size', [1, 2, 4])
@pytest.mark.parametrize('seq_length', [16])
-def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length):
+@pytest.mark.parametrize('bidirectional', [True, False])
+def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length, bidirectional):
# TODO: The current implementation fails assertion checks for large param/state_size.
# for num_layers >= 2, input_size must equal state_size
if num_layers >= 2 and input_size != state_size:
return
+ # Currently bidirectional export is only supported for lstm with num_layers = 1
+ if bidirectional and (mode != 'lstm' or num_layers != 1):
+ return
+
+ b = 1
+ if bidirectional:
+ b = 2
+
factor = 1
if mode == 'gru':
factor = 3
elif mode == 'lstm':
factor = 4
- M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True, num_layers=num_layers, p=0)
+ M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True, num_layers=num_layers, p=0, bidirectional=bidirectional)
x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size), dtype=dtype)
- param = mx.nd.random.normal(0, 1, [num_layers*factor*state_size*input_size +
- num_layers*factor*state_size*state_size +
- num_layers*2*factor*state_size], dtype=dtype)
- state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+ param = mx.nd.random.normal(0, 1, [b*num_layers*factor*state_size*input_size +
+ b*num_layers*factor*state_size*state_size +
+ b*num_layers*2*factor*state_size], dtype=dtype)
+ state = mx.nd.random.uniform(-1, 1, [b*num_layers, batch_size, state_size], dtype=dtype)
if mode == 'lstm':
- cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+ cell = mx.nd.random.uniform(-1, 1, [b*num_layers, batch_size, state_size], dtype=dtype)
op_export_test('rnn', M, [x, param, state, cell], tmp_path)
elif mode == 'rnn_relu':
# set large atol as relu can output big numbers

From f2961fb45dd56f00fc215571ffe4efb91fe089e4 Mon Sep 17 00:00:00 2001
From: Wei Chu
Date: Wed, 5 May 2021 00:54:14 -0700
Subject: [PATCH 2/4] fix sanity

---
.../_op_translations_opset12.py | 37 +++++++++------
.../_op_translations_opset13.py | 47 ++++++++++++-------
2 files changed, 52 insertions(+), 32 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index 0c26aa2b3003..dea0444743b5 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ 
b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -4621,13 +4621,16 @@ def convert_RNN(node, **kwargs): nodes += [ make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + make_node('Split', [name+'_data_shape'], + [name+'_seq_length', name+'_batch_size', name+'_input_size']), # get W_fwd make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), - make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], + [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], + [name+'_W_shape'], axis=0), make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']), # get R_fwd make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), @@ -4638,15 +4641,18 @@ def convert_RNN(node, **kwargs): # get W_bwd make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']), make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']), - make_node('Split', [name+'_W_1d_bwd'], [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), - make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], [name+'_W_bwd_'], axis=0), - # make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Split', [name+'_W_1d_bwd'], + [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), + make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], + [name+'_W_bwd_'], axis=0), make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']), # get R_bwd make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']), make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']), - make_node('Split', [name+'_R_1d_bwd'], [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), - make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], [name+'_R_bwd_'], axis=0), + make_node('Split', [name+'_R_1d_bwd'], + [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), + make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], + [name+'_R_bwd_'], axis=0), make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']), # get B_fwd make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']), @@ -4654,15 +4660,17 @@ def convert_RNN(node, **kwargs): make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', name+'_B4', name+'_B5', name+'_B6', name+'_B7']), make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']), # get B_bwd make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']), make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']), - make_node('Split', [name+'_B_1d_bwd'], [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', 
name+'_B3_bwd', - name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), + make_node('Split', [name+'_B_1d_bwd'], + [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd', + name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd', - name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], [name+'_B_bwd_'], axis=0), + name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], + [name+'_B_bwd_'], axis=0), make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']), # get seq_len make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), @@ -4672,9 +4680,10 @@ def convert_RNN(node, **kwargs): make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0), make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0), make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), + [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]), - make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], [name+'_shape_out'], axis=0), + make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], + [name+'_shape_out'], axis=0), make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]), ] else: diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 4095f6036d5d..ea856195519d 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1129,13 +1129,15 @@ def convert_RNN(node, **kwargs): nodes += [ make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + make_node('Split', [name+'_data_shape'], + [name+'_seq_length', name+'_batch_size', name+'_input_size']), # get W make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], + [name+'_W_shape'], axis=0), make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), # get R make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), @@ -1149,14 +1151,14 @@ def convert_RNN(node, **kwargs): make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', name+'_B4', name+'_B5', name+'_B6', name+'_B7']), make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), # get seq_len make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], 
to=int(TensorProto.INT32)), # compute LSTM make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size), + [name+'0_', name+'1', name+'2'], hidden_size=state_size), make_node('Squeeze', [name+'0_', name+'_1'], [name]), ] else: @@ -1169,13 +1171,16 @@ def convert_RNN(node, **kwargs): nodes += [ make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + make_node('Split', [name+'_data_shape'], + [name+'_seq_length', name+'_batch_size', name+'_input_size']), # get W_fwd make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), - make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], + [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], + [name+'_W_shape'], axis=0), make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']), # get R_fwd make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), @@ -1186,15 +1191,18 @@ def convert_RNN(node, **kwargs): # get W_bwd make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']), make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']), - make_node('Split', [name+'_W_1d_bwd'], [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), - make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], [name+'_W_bwd_'], axis=0), - # make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Split', [name+'_W_1d_bwd'], + [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']), + make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'], + [name+'_W_bwd_'], axis=0), make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']), # get R_bwd make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']), make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']), - make_node('Split', [name+'_R_1d_bwd'], [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), - make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], [name+'_R_bwd_'], axis=0), + make_node('Split', [name+'_R_1d_bwd'], + [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']), + make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'], + [name+'_R_bwd_'], axis=0), make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']), # get B_fwd make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']), @@ -1202,15 +1210,17 @@ def convert_RNN(node, **kwargs): make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', name+'_B4', name+'_B5', name+'_B6', name+'_B7']), make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']), # 
get B_bwd make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']), make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']), - make_node('Split', [name+'_B_1d_bwd'], [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd', - name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), + make_node('Split', [name+'_B_1d_bwd'], + [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd', + name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']), make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd', - name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], [name+'_B_bwd_'], axis=0), + name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'], + [name+'_B_bwd_'], axis=0), make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']), # get seq_len make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), @@ -1220,9 +1230,10 @@ def convert_RNN(node, **kwargs): make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0), make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0), make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), + [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'), make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]), - make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], [name+'_shape_out'], axis=0), + make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'], + [name+'_shape_out'], axis=0), make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]), ] else: From 5c3e89cfb60c529b90e2d151a42371dfc6266014 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 5 May 2021 09:21:30 -0700 Subject: [PATCH 3/4] fix sanity --- .../_op_translations/_op_translations_opset12.py | 10 ++++++---- .../_op_translations/_op_translations_opset13.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index dea0444743b5..da6339d9a767 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -4580,13 +4580,15 @@ def convert_RNN(node, **kwargs): nodes += [ make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + make_node('Split', [name+'_data_shape'], + [name+'_seq_length', name+'_batch_size', name+'_input_size']), # get W make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], + [name+'_W_shape'], axis=0), make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), # get R make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), @@ -4600,14 +4602,14 @@ def convert_RNN(node, **kwargs): 
make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', name+'_B4', name+'_B5', name+'_B6', name+'_B7']), make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), # get seq_len make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), # compute LSTM make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size), + [name+'0_', name+'1', name+'2'], hidden_size=state_size), make_node('Squeeze', [name+'0_'], [name], axes=[1]), ] else: diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index ea856195519d..b32cc94bf819 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1129,14 +1129,14 @@ def convert_RNN(node, **kwargs): nodes += [ make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), # get W make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), # get R From 1d6589a1ad7358e2b1f77c1250ce9572a68ff9a9 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 5 May 2021 13:40:34 -0700 Subject: [PATCH 4/4] reduce state_size --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 8d7ce857bbf0..7cf5c00b5702 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1237,7 +1237,7 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params): @pytest.mark.parametrize('mode', ['lstm', 'gru', 'rnn_tanh', 'rnn_relu']) @pytest.mark.parametrize('dtype', ['float32']) -@pytest.mark.parametrize('state_size', [16, 32, 64]) +@pytest.mark.parametrize('state_size', [16, 32]) @pytest.mark.parametrize('input_size', [16, 32, 64]) @pytest.mark.parametrize('num_layers', [1, 2]) @pytest.mark.parametrize('batch_size', [1, 2, 4])
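
Note on the reshape special case (patch 1): in MXNet's reshape, -3 collapses the first two input dimensions into one and -1 infers the remainder, so shape=(-3, -1) maps an input of shape (d0, d1, d2, ..., dn) to (d0*d1, d2*...*dn). ONNX Reshape has no -3 convention, which is why the exported subgraph assembles the target shape at runtime with Shape/Slice/Mul/Concat. A minimal numpy sketch of the intended shape arithmetic (the array contents are illustrative only):

    import numpy as np

    x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

    # What the exported Shape/Slice/Mul/Concat subgraph computes:
    d0 = x.shape[0]          # Slice(shape, 0, 1) -> '_1st_dim'
    d1 = x.shape[1]          # Slice(shape, 1, 2) -> '_2nd_dim'
    target = [d0 * d1, -1]   # Mul, then Concat with the constant [-1]

    y = x.reshape(target)
    assert y.shape == (6, 20)   # (2*3, 4*5)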
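Note on the bidirectional LSTM parameter unpacking (patches 1-2): the chained Add nodes (_add0 through _add4) compute slice offsets into MXNet's flat fused parameter vector, which for one bidirectional layer is laid out as W_fwd, R_fwd, W_bwd, R_bwd, B_fwd, B_bwd, and the [X0, X3, X1, X2] Split/Concat pairs reorder the four gate blocks from MXNet's (i, f, g, o) layout into the (i, o, f, c) order ONNX LSTM expects. A rough sketch of the offset arithmetic, assuming mode='lstm' and num_layers=1 (the helper name is hypothetical; the constants mirror the initializer tensors the converter creates):

    def lstm_param_offsets(input_size, state_size):
        # One direction's blocks, matching the converter's constant tensors:
        w = 4 * state_size * input_size   # W block  ('_mul0')
        r = 4 * state_size * state_size   # R block  ('_4*state_size^2')
        b = 8 * state_size                # bias block, Wb + Rb ('_8*state_size')
        bounds = [0, w, w + r, 2*w + r, 2*(w + r), 2*(w + r) + b, 2*(w + r) + 2*b]
        names = ['W_fwd', 'R_fwd', 'W_bwd', 'R_bwd', 'B_fwd', 'B_bwd']
        return list(zip(names, bounds[:-1], bounds[1:]))

    # Each (start, end) pair corresponds to one Slice over the param input:
    for name, start, end in lstm_param_offsets(input_size=16, state_size=32):
        print(name, start, end)

The total agrees with the parameter count the new test allocates (b*factor*state_size*input_size + b*factor*state_size*state_size + b*2*factor*state_size, with b=2 directions and factor=4 gates).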
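Note on the output layout fix-up: an ONNX bidirectional LSTM emits Y with shape (seq_length, num_directions, batch_size, hidden_size), while MXNet's fused RNN layer returns (seq_length, batch_size, 2*hidden_size); hence the trailing Transpose with perm=[0, 2, 1, 3] followed by a Reshape to [seq_length, batch_size, -1]. A small numpy sketch of the equivalent transform (the sizes are arbitrary examples):

    import numpy as np

    seq_length, num_directions, batch_size, hidden_size = 16, 2, 4, 32
    y_onnx = np.zeros((seq_length, num_directions, batch_size, hidden_size))

    # Transpose(perm=[0, 2, 1, 3]) then Reshape([seq_length, batch_size, -1]):
    y_mx = y_onnx.transpose(0, 2, 1, 3).reshape(seq_length, batch_size, -1)
    assert y_mx.shape == (16, 4, 64)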