diff --git a/tests/keras2onnx_unit_tests/test_layers.py b/tests/keras2onnx_unit_tests/test_layers.py
index 09bcc7899..1e32ea5dd 100644
--- a/tests/keras2onnx_unit_tests/test_layers.py
+++ b/tests/keras2onnx_unit_tests/test_layers.py
@@ -7,7 +7,7 @@
 from mock_keras2onnx.proto import (keras, is_tf_keras,
                                    is_tensorflow_older_than, is_tensorflow_later_than,
                                    is_keras_older_than, is_keras_later_than)
-from test_utils import no_loops_in_tf2
+from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional

 K = keras.backend
 Activation = keras.layers.Activation
@@ -2073,6 +2073,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
     model.add(Activation('softmax'))
     model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
     onnx_model = convert_keras(model, 'test', target_opset=op_version)
+    assert all_recurrents_should_bidirectional(onnx_model)
     for batch in batch_list:
         data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
         expected = model.predict(data)
@@ -2084,6 +2085,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
                                              input_shape=(5, 10), merge_mode=merge_mode)(sub_input1)
         keras_model = keras.Model(inputs=sub_input1, outputs=sub_mapped1)
         onnx_model = convert_keras(keras_model, 'test_2', target_opset=op_version)
+        assert all_recurrents_should_bidirectional(onnx_model)
         for batch in batch_list:
             data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
             expected = keras_model.predict(data)
@@ -2102,6 +2104,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
     # Test with the default bias
     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)

     # Set bias values to random floats
@@ -2114,6 +2117,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
     # Test with random bias
     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2141,6 +2145,7 @@ def test_bidirectional_time_major_true(runner, rnn_class):

     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2155,6 +2160,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

     expected = model.predict(inputs)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, inputs, expected)

     input2 = Input(shape=(None, 5))
@@ -2165,6 +2171,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

     expected = model.predict(inputs)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, inputs, expected, atol=1e-5)


@@ -2180,6 +2187,7 @@ def test_bidirectional_seqlen_none(runner, rnn_class):
     model.add(Dense(44))

     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     for batch in [1, 4]:
         x = np.random.rand(batch, 50).astype(np.float32)
         expected = model.predict(x)
diff --git a/tests/keras2onnx_unit_tests/test_utils.py b/tests/keras2onnx_unit_tests/test_utils.py
index 59ae3b229..3706d5eea 100644
--- a/tests/keras2onnx_unit_tests/test_utils.py
+++ b/tests/keras2onnx_unit_tests/test_utils.py
@@ -3,6 +3,7 @@
 import os
 import sys
 import onnx
+from onnx import helper
 import numpy as np
 import mock_keras2onnx
 from mock_keras2onnx.proto import keras, is_keras_older_than
@@ -161,6 +162,14 @@ def no_loops_in_tf2(onnx_model):
     return not is_tf2 or all(n.op_type != "Loop" for n in onnx_model.graph.node)


+def all_recurrents_should_bidirectional(onnx_model):
+    return all([
+        helper.get_attribute_value(attr) == b'bidirectional'
+        for node in onnx_model.graph.node if node.op_type in ['GRU', 'LSTM', 'RNN']
+        for attr in node.attribute if attr.name == 'direction'
+    ])
+
+
 def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.e-3, atol=1.e-6,
                      compare_perf=False, enable_profiling=False):
     if not os.path.exists(tmp_path):
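A minimal sketch (not part of the patch) of what the new `all_recurrents_should_bidirectional` helper asserts: every `GRU`/`LSTM`/`RNN` node in the graph must carry `direction == b'bidirectional'` (ONNX stores string attributes as bytes). The hand-built model below is illustrative; weight initializers are omitted because the helper only inspects node attributes and never executes the graph. Note that a node which omits the `direction` attribute entirely (the `'forward'` default) is skipped by the comprehension and so passes vacuously.

```python
from onnx import helper, TensorProto
from test_utils import all_recurrents_should_bidirectional  # helper added above

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [5, 1, 10])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [5, 2, 1, 4])
# One bidirectional LSTM node; 'W'/'R' are dangling placeholder inputs,
# which is fine here since nothing runs the model or calls the checker.
lstm = helper.make_node('LSTM', ['X', 'W', 'R'], ['Y'],
                        hidden_size=4, direction='bidirectional')
model = helper.make_model(helper.make_graph([lstm], 'birnn', [X], [Y]))

assert all_recurrents_should_bidirectional(model)
```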
diff --git a/tf2onnx/rewriter/rnn_utils.py b/tf2onnx/rewriter/rnn_utils.py
index 762328665..4e3912004 100644
--- a/tf2onnx/rewriter/rnn_utils.py
+++ b/tf2onnx/rewriter/rnn_utils.py
@@ -485,26 +485,42 @@ def find_bidirectional_rnns(g, ops, rnn_type):
         input_id = n.input[0]
         temp = n.inputs[0]
         is_bw = False
+        is_transposed = False
         if temp.type == "Transpose":
             input_id = temp.input[0]
             temp = temp.inputs[0]
+            is_transposed = True

         if utils.is_tf_reverse_op(temp):
             input_id = temp.input[0]
+            temp = temp.inputs[0]
             is_bw = True

+        if (not is_transposed) and temp.type == "Transpose":
+            input_id = temp.input[0]
+            temp = temp.inputs[0]
+
+        input_ids = [input_id]
+        if temp.type == "Identity":
+            input_ids.append(temp.input[0])
+            temp = temp.inputs[0]
+            if temp.type == "Identity":
+                input_ids.append(temp.input[0])
+
         if is_bw:
             # if output 0 is consumed and there is no reverse after the 1st output.
             # it's not backward rnn.
-            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
+            if g.find_output_consumers(n.output[0]) and not get_reverse_or_slice_nodes_after_y_output(g, n):
                 logger.warning("rnn %s following Reverse op isn't the part of bi-rnn.", n.name)
                 continue

-            logger.debug("find bw rnn %s", input_id)
-            bw_rnns[input_id].append(n)
+            logger.debug("find bw rnn %s", input_ids)
+            for input_id in input_ids:
+                bw_rnns[input_id].append(n)
         else:
-            logger.debug("find fw rnn %s", input_id)
-            fw_rnns[input_id].append(n)
+            logger.debug("find fw rnn %s", input_ids)
+            for input_id in input_ids:
+                fw_rnns[input_id].append(n)

     # fw_rnn and bw_rnn must share the same input
     birnn_input = list(set(fw_rnns.keys()).intersection(bw_rnns.keys()))
@@ -554,7 +570,17 @@ def belong_to_birnn(g, fw_rnn, bw_rnn, rnn_type):
     return True


-def get_reverse_nodes_after_y_output(g, rnn_bw):
+def is_tail_slice_op(node):
+    return (
+        node.type == 'StridedSlice' and
+        node.inputs[1].get_tensor_value() == [-1] and
+        node.inputs[2].get_tensor_value() == [0] and
+        node.inputs[3].get_tensor_value() == [1] and
+        node.get_attr('shrink_axis_mask').i == 1
+    )
+
+
+def get_reverse_or_slice_nodes_after_y_output(g, rnn_bw):
     bw_consumers = g.find_output_consumers(rnn_bw.output[0])

     # todo: figure out a better way to remove reverse op
@@ -562,19 +588,22 @@
     s_cnt = len(squeeze_nodes)
     if s_cnt == 1:
         s = squeeze_nodes[0]
-        trans_nodes = g.find_output_consumers(s.output[0])
-        if len(trans_nodes) == 1:
-            if trans_nodes[0].type == "Transpose":
-                reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
-            elif utils.is_tf_reverse_op(trans_nodes[0]):
-                reverse_nodes = trans_nodes
-            else:
-                logger.debug("not found reverse op, unexpected")
-                return []
-
-            are_all_reverse = all([utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes])
-            if are_all_reverse:
-                return reverse_nodes
+        reverse_or_slice_nodes = g.find_output_consumers(s.output[0])
+        if len(reverse_or_slice_nodes) == 1:
+            if reverse_or_slice_nodes[0].type == "Transpose":
+                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+
+            if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
+                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+                if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
+                    reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+
+            are_all_reverse_or_slice = all([
+                utils.is_tf_reverse_op(r_op) or is_tail_slice_op(r_op)
+                for r_op in reverse_or_slice_nodes
+            ])
+            if are_all_reverse_or_slice:
+                return reverse_or_slice_nodes

     logger.debug("bw y output is used followed by reverse node")
     return []
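For context on the pattern `is_tail_slice_op` matches: with `return_sequences=False`, Keras keeps only the last time step of the backward pass, and TF lowers that indexing to a `StridedSlice` with `begin=[-1]`, `end=[0]`, `strides=[1]` and `shrink_axis_mask=1` (take the last element, then drop the sliced axis). A numpy sketch of those semantics, assuming a `[seq, batch, hidden]` layout:

```python
import numpy as np

y = np.arange(24, dtype=np.float32).reshape(6, 1, 4)  # [seq, batch, hidden]
tail = y[-1]                   # what StridedSlice(-1, 0, 1, shrink_axis_mask=1) computes
assert tail.shape == (1, 4)    # the time axis is shrunk away, not kept as size 1
np.testing.assert_array_equal(tail, np.squeeze(y[-1:], axis=0))  # slice + squeeze equivalent
```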
@@ -619,13 +648,28 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output
     if rnn_output_index == 0:
         axis = 1

-        # remove reverse op for rnn_bw
-        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)
-
-        for r_op in reverse_nodes:
-            logger.debug("remove reverse op %s", r_op.name)
-            g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
-            to_remove.append(r_op.name)
+        # remove reverse (return_sequences=True) or tail slice (return_sequences=False) op for rnn_bw
+        reverse_or_slice_nodes = get_reverse_or_slice_nodes_after_y_output(g, rnn_bw)
+
+        for r_op in reverse_or_slice_nodes:
+            if utils.is_tf_reverse_op(r_op):
+                logger.debug("remove reverse op %s", r_op.name)
+                g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
+                to_remove.append(r_op.name)
+            elif is_tail_slice_op(r_op):
+                # in case of return_sequences=False,
+                # replace output[-1:] with output[0:1]
+                attr = {"axes": [0], "starts": [0], "ends": [1]}
+                inputs_map = {"data": r_op.input[0], **attr}
+                slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
+                all_nodes.append(g.get_node_by_output(slice_node_bw))
+
+                inputs_map = {"data": slice_node_bw, "axes": [0]}
+                squeeze_node_bw = GraphBuilder(g).make_squeeze(inputs_map)
+                all_nodes.append(g.get_node_by_output(squeeze_node_bw))
+
+                g.replace_all_inputs(r_op.output[0], squeeze_node_bw, ops=all_nodes)
+                to_remove.append(r_op.name)
     elif rnn_output_index in [1, 2]:
         axis = 0
     else:
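The `Slice(starts=[0], ends=[1])` plus `Squeeze` replacement above encodes a time-index flip: TF computes the backward pass by reversing the input and running forward, so the *last* step of that run is the backward state for *original* step 0, which is the index where the merged bidirectional op's Y output stores it. An illustrative numpy sketch (not part of the patch), using a cumulative sum as a stand-in for the recurrence:

```python
import numpy as np

def scan(x):                        # stand-in recurrence: running sum over time
    return np.cumsum(x, axis=0)

x = np.random.rand(5, 3)            # [seq, features]
tf_backward = scan(x[::-1])         # TF: reverse input, then run forward
onnx_backward = tf_backward[::-1]   # merged op: backward Y re-aligned to input order

np.testing.assert_allclose(tf_backward[-1:], onnx_backward[0:1])  # [-1:] -> [0:1]
```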