Fix keras bidirectional merge failures #1869

Merged · 4 commits · Apr 2, 2022
Changes from all commits
10 changes: 9 additions & 1 deletion tests/keras2onnx_unit_tests/test_layers.py
@@ -7,7 +7,7 @@
from mock_keras2onnx.proto import (keras, is_tf_keras,
is_tensorflow_older_than, is_tensorflow_later_than,
is_keras_older_than, is_keras_later_than)
from test_utils import no_loops_in_tf2
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional

K = keras.backend
Activation = keras.layers.Activation
@@ -2073,6 +2073,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
onnx_model = convert_keras(model, 'test', target_opset=op_version)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in batch_list:
data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
expected = model.predict(data)
@@ -2084,6 +2085,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
input_shape=(5, 10), merge_mode=merge_mode)(sub_input1)
keras_model = keras.Model(inputs=sub_input1, outputs=sub_mapped1)
onnx_model = convert_keras(keras_model, 'test_2', target_opset=op_version)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in batch_list:
data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
expected = keras_model.predict(data)
@@ -2102,6 +2104,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
# Test with the default bias
expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)

# Set bias values to random floats
@@ -2114,6 +2117,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
# Test with random bias
expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2141,6 +2145,7 @@ def test_bidirectional_time_major_true(runner, rnn_class):

expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2155,6 +2160,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

expected = model.predict(inputs)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, inputs, expected)

input2 = Input(shape=(None, 5))
@@ -2165,6 +2171,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

expected = model.predict(inputs)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, inputs, expected, atol=1e-5)


@@ -2180,6 +2187,7 @@ def test_bidirectional_seqlen_none(runner, rnn_class):
model.add(Dense(44))

onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in [1, 4]:
x = np.random.rand(batch, 50).astype(np.float32)
expected = model.predict(x)
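For readers who want to reproduce the bidirectional check outside this test harness, a minimal sketch is below. It uses the public tf2onnx.convert.from_keras API rather than the tests' convert_keras helper; the layer sizes, opset, and tensor names are illustrative assumptions, and the final assertion holds only when the bidirectional rewriter actually produces an ONNX LSTM node.

    # Minimal sketch: convert a Keras Bidirectional LSTM and inspect the direction attribute.
    # Assumptions: tensorflow and tf2onnx are installed; shapes, opset and names are arbitrary.
    import tensorflow as tf
    import tf2onnx
    from onnx import helper

    model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(8, return_sequences=False),
            merge_mode='concat', input_shape=(5, 10)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])

    spec = (tf.TensorSpec((None, 5, 10), tf.float32, name='input'),)
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

    # Same condition as the new test helper: every recurrent node is bidirectional.
    directions = [helper.get_attribute_value(attr)
                  for node in onnx_model.graph.node
                  if node.op_type in ('GRU', 'LSTM', 'RNN')
                  for attr in node.attribute if attr.name == 'direction']
    assert directions and all(d == b'bidirectional' for d in directions)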
9 changes: 9 additions & 0 deletions tests/keras2onnx_unit_tests/test_utils.py
@@ -3,6 +3,7 @@
import os
import sys
import onnx
from onnx import helper
import numpy as np
import mock_keras2onnx
from mock_keras2onnx.proto import keras, is_keras_older_than
@@ -161,6 +162,14 @@ def no_loops_in_tf2(onnx_model):
return not is_tf2 or all(n.op_type != "Loop" for n in onnx_model.graph.node)


def all_recurrents_should_bidirectional(onnx_model):
return all([
helper.get_attribute_value(attr) == b'bidirectional'
for node in onnx_model.graph.node if node.op_type in ['GRU', 'LSTM', 'RNN']
for attr in node.attribute if attr.name == 'direction'
])


def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.e-3, atol=1.e-6,
compare_perf=False, enable_profiling=False):
if not os.path.exists(tmp_path):
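One detail of the helper above: ONNX stores string attributes as bytes, so helper.get_attribute_value returns b'bidirectional' rather than a Python str, which is why the comparison uses a bytes literal. A small self-contained sketch (the tensor names are placeholders):

    from onnx import helper

    # Hand-built LSTM node; only the attributes matter here, the tensor names are dummies.
    node = helper.make_node(
        'LSTM', inputs=['X', 'W', 'R'], outputs=['Y', 'Y_h', 'Y_c'],
        hidden_size=8, direction='bidirectional')

    value = [helper.get_attribute_value(a) for a in node.attribute if a.name == 'direction'][0]
    assert value == b'bidirectional'            # bytes, not str
    assert value.decode('utf-8') == 'bidirectional'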
96 changes: 70 additions & 26 deletions tf2onnx/rewriter/rnn_utils.py
@@ -485,26 +485,42 @@ def find_bidirectional_rnns(g, ops, rnn_type):
input_id = n.input[0]
temp = n.inputs[0]
is_bw = False
is_transposed = False
if temp.type == "Transpose":
input_id = temp.input[0]
temp = temp.inputs[0]
is_transposed = True

if utils.is_tf_reverse_op(temp):
input_id = temp.input[0]
temp = temp.inputs[0]
is_bw = True

if (not is_transposed) and temp.type == "Transpose":
input_id = temp.input[0]
temp = temp.inputs[0]

input_ids = [input_id]
if temp.type == "Identity":
input_ids.append(temp.input[0])
temp = temp.inputs[0]
if temp.type == "Identity":
input_ids.append(temp.input[0])

if is_bw:
# if output 0 is consumed and there is no reverse after the 1st output.
# it's not backward rnn.
if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
if g.find_output_consumers(n.output[0]) and not get_reverse_or_slice_nodes_after_y_output(g, n):
logger.warning("rnn %s following Reverse op isn't the part of bi-rnn.", n.name)
continue

logger.debug("find bw rnn %s", input_id)
bw_rnns[input_id].append(n)
logger.debug("find bw rnn %s", input_ids)
for input_id in input_ids:
bw_rnns[input_id].append(n)
else:
logger.debug("find fw rnn %s", input_id)
fw_rnns[input_id].append(n)
logger.debug("find fw rnn %s", input_ids)
for input_id in input_ids:
fw_rnns[input_id].append(n)

# fw_rnn and bw_rnn must share the same input
birnn_input = list(set(fw_rnns.keys()).intersection(bw_rnns.keys()))
@@ -554,27 +570,40 @@ def belong_to_birnn(g, fw_rnn, bw_rnn, rnn_type):
return True


def get_reverse_nodes_after_y_output(g, rnn_bw):
def is_tail_slice_op(node):
return (
node.type == 'StridedSlice' and
node.inputs[1].get_tensor_value() == [-1] and
node.inputs[2].get_tensor_value() == [0] and
node.inputs[3].get_tensor_value() == [1] and
node.get_attr('shrink_axis_mask').i == 1
)


def get_reverse_or_slice_nodes_after_y_output(g, rnn_bw):
bw_consumers = g.find_output_consumers(rnn_bw.output[0])

# todo: figure out a better way to remove reverse op
squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
s_cnt = len(squeeze_nodes)
if s_cnt == 1:
s = squeeze_nodes[0]
trans_nodes = g.find_output_consumers(s.output[0])
if len(trans_nodes) == 1:
if trans_nodes[0].type == "Transpose":
reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
elif utils.is_tf_reverse_op(trans_nodes[0]):
reverse_nodes = trans_nodes
else:
logger.debug("not found reverse op, unexpected")
return []

are_all_reverse = all([utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes])
if are_all_reverse:
return reverse_nodes
reverse_or_slice_nodes = g.find_output_consumers(s.output[0])
if len(reverse_or_slice_nodes) == 1:
if reverse_or_slice_nodes[0].type == "Transpose":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

Collaborator:

Do we need to do it twice here? Or update it like:

        if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
            reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
            if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

Contributor (author):

Yes, TensorFlow emits one or two Identity nodes, depending on the layer arguments. I changed the code to a nested if.
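For context, the bounded Identity hop could also be written as a small loop. This is only an illustrative alternative, not the merged code, and it assumes the same graph API used above (find_output_consumers, node.type, node.output):

    # Illustrative alternative: follow at most max_hops single-consumer Identity nodes,
    # since TensorFlow may insert one or two of them depending on the layer arguments.
    def skip_identity_consumers(g, nodes, max_hops=2):
        for _ in range(max_hops):
            if len(nodes) == 1 and nodes[0].type == 'Identity':
                nodes = g.find_output_consumers(nodes[0].output[0])
            else:
                break
        return nodes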

are_all_reverse_or_slice = all([
utils.is_tf_reverse_op(r_op) or is_tail_slice_op(r_op)
for r_op in reverse_or_slice_nodes
])
if are_all_reverse_or_slice:
return reverse_or_slice_nodes

logger.debug("bw y output is used followed by reverse node")
return []
@@ -619,13 +648,28 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output

if rnn_output_index == 0:
axis = 1
# remove reverse op for rnn_bw
reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

for r_op in reverse_nodes:
logger.debug("remove reverse op %s", r_op.name)
g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
to_remove.append(r_op.name)
# remove reverse(return_sequence=True) or tail slice(return_sequence=False) op for rnn_bw
reverse_or_slice_nodes = get_reverse_or_slice_nodes_after_y_output(g, rnn_bw)

for r_op in reverse_or_slice_nodes:
if utils.is_tf_reverse_op(r_op):
logger.debug("remove reverse op %s", r_op.name)
g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
to_remove.append(r_op.name)
elif is_tail_slice_op(r_op):
# in case of return_sequence=False
# replace output[-1:] to output[0:1]
attr = {"axes": [0], "starts": [0], "ends": [1]}
inputs_map = {"data": r_op.input[0], **attr}
slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
all_nodes.append(g.get_node_by_output(slice_node_bw))

inputs_map = {"data": slice_node_bw, "axes": [0]}
squeeze_node_bw = GraphBuilder(g).make_squeeze(inputs_map)
all_nodes.append(g.get_node_by_output(squeeze_node_bw))

g.replace_all_inputs(r_op.output[0], squeeze_node_bw, ops=all_nodes)
to_remove.append(r_op.name)
elif rnn_output_index in [1, 2]:
axis = 0
else:
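As a sanity check on the slice rewrite above (output[-1:] replaced by a [0:1] Slice plus a Squeeze): TensorFlow's backward layer scans the reversed input and keeps the last step, while the ONNX bidirectional Y output stores the reverse direction aligned with the original time axis, so that same state lives at time index 0. A toy NumPy recurrence, purely illustrative, makes the alignment concrete:

    import numpy as np

    def toy_scan(xs):
        # Stand-in for an RNN cell: h_t = 0.5 * h_{t-1} + x_t, returning every step.
        h, outs = 0.0, []
        for x in xs:
            h = 0.5 * h + x
            outs.append(h)
        return np.array(outs)

    x = np.array([1.0, 2.0, 3.0, 4.0])
    tf_backward = toy_scan(x[::-1])       # TF backward layer: scan the reversed input
    onnx_reverse_y = tf_backward[::-1]    # ONNX: reverse-direction Y in original time order

    # The tail slice on the TF side selects the same value as index 0 on the ONNX side.
    assert tf_backward[-1] == onnx_reverse_y[0]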