Fix keras bidirectional merge failures (#1869)
* add a test to check that Keras bidirectional recurrent layers are merged

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>

* fix keras bidirectional merge failures

Support the following cases (a minimal sketch of the last case follows this list):
- one or two Identity layers between the input/output and the RNN
- Transpose-Reverse-backward ordering (previously, only Reverse-Transpose-backward was supported)
- return_sequences=False with no Reverse after the backward RNN
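A minimal sketch of the last case, assuming the public tf2onnx.convert.from_keras entry point (the tests below use the keras2onnx-style convert_keras wrapper instead); with return_sequences=False the backward pass ends in a tail slice (y[-1]) rather than a Reverse op:

import tensorflow as tf
import tf2onnx

# Bidirectional LSTM with return_sequences=False: before this fix, the
# rewriter only recognized a Reverse op after the backward RNN, so the two
# halves stayed as separate unidirectional LSTM nodes.
model = tf.keras.Sequential([
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(16, return_sequences=False),
        input_shape=(10, 8)),
    tf.keras.layers.Dense(4, activation='softmax'),
])
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=13)
# With the fix, the graph should contain a single LSTM node whose
# direction attribute is 'bidirectional'.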

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>

* apply review comments for the Bidirectional fix

- changed consecutive Identity checks to nested ifs
- updated the comment for removing the Reverse or tail-slice op

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>

Co-authored-by: Jay Zhang <36183870+fatcat-z@users.noreply.github.com>
kota-row and fatcat-z authored Apr 2, 2022
1 parent 0e3720c commit de7e5af
Showing 3 changed files with 88 additions and 27 deletions.
10 changes: 9 additions & 1 deletion tests/keras2onnx_unit_tests/test_layers.py
@@ -7,7 +7,7 @@
from mock_keras2onnx.proto import (keras, is_tf_keras,
is_tensorflow_older_than, is_tensorflow_later_than,
is_keras_older_than, is_keras_later_than)
from test_utils import no_loops_in_tf2
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional

K = keras.backend
Activation = keras.layers.Activation
@@ -2073,6 +2073,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
onnx_model = convert_keras(model, 'test', target_opset=op_version)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in batch_list:
data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
expected = model.predict(data)
@@ -2084,6 +2085,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
input_shape=(5, 10), merge_mode=merge_mode)(sub_input1)
keras_model = keras.Model(inputs=sub_input1, outputs=sub_mapped1)
onnx_model = convert_keras(keras_model, 'test_2', target_opset=op_version)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in batch_list:
data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
expected = keras_model.predict(data)
@@ -2102,6 +2104,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
# Test with the default bias
expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)

# Set bias values to random floats
@@ -2114,6 +2117,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
# Test with random bias
expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2141,6 +2145,7 @@ def test_bidirectional_time_major_true(runner, rnn_class):

expected = model.predict(x)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, x, expected)


@@ -2155,6 +2160,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

expected = model.predict(inputs)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, inputs, expected)

input2 = Input(shape=(None, 5))
@@ -2165,6 +2171,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

expected = model.predict(inputs)
onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
assert runner(onnx_model.graph.name, onnx_model, inputs, expected, atol=1e-5)


@@ -2180,6 +2187,7 @@ def test_bidirectional_seqlen_none(runner, rnn_class):
model.add(Dense(44))

onnx_model = convert_keras(model, model.name)
assert all_recurrents_should_bidirectional(onnx_model)
for batch in [1, 4]:
x = np.random.rand(batch, 50).astype(np.float32)
expected = model.predict(x)
9 changes: 9 additions & 0 deletions tests/keras2onnx_unit_tests/test_utils.py
@@ -3,6 +3,7 @@
import os
import sys
import onnx
from onnx import helper
import numpy as np
import mock_keras2onnx
from mock_keras2onnx.proto import keras, is_keras_older_than
@@ -161,6 +162,14 @@ def no_loops_in_tf2(onnx_model):
return not is_tf2 or all(n.op_type != "Loop" for n in onnx_model.graph.node)


def all_recurrents_should_bidirectional(onnx_model):
return all([
helper.get_attribute_value(attr) == b'bidirectional'
for node in onnx_model.graph.node if node.op_type in ['GRU', 'LSTM', 'RNN']
for attr in node.attribute if attr.name == 'direction'
])


def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.e-3, atol=1.e-6,
compare_perf=False, enable_profiling=False):
if not os.path.exists(tmp_path):
96 changes: 70 additions & 26 deletions tf2onnx/rewriter/rnn_utils.py
@@ -485,26 +485,42 @@ def find_bidirectional_rnns(g, ops, rnn_type):
input_id = n.input[0]
temp = n.inputs[0]
is_bw = False
is_transposed = False
if temp.type == "Transpose":
input_id = temp.input[0]
temp = temp.inputs[0]
is_transposed = True

if utils.is_tf_reverse_op(temp):
input_id = temp.input[0]
temp = temp.inputs[0]
is_bw = True

if (not is_transposed) and temp.type == "Transpose":
input_id = temp.input[0]
temp = temp.inputs[0]

input_ids = [input_id]
if temp.type == "Identity":
input_ids.append(temp.input[0])
temp = temp.inputs[0]
if temp.type == "Identity":
input_ids.append(temp.input[0])

if is_bw:
# If output 0 is consumed and there is no Reverse (or tail slice) after
# the 1st output, it's not a backward rnn.
if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
if g.find_output_consumers(n.output[0]) and not get_reverse_or_slice_nodes_after_y_output(g, n):
logger.warning("rnn %s following Reverse op isn't the part of bi-rnn.", n.name)
continue

logger.debug("find bw rnn %s", input_id)
bw_rnns[input_id].append(n)
logger.debug("find bw rnn %s", input_ids)
for input_id in input_ids:
bw_rnns[input_id].append(n)
else:
logger.debug("find fw rnn %s", input_id)
fw_rnns[input_id].append(n)
logger.debug("find fw rnn %s", input_ids)
for input_id in input_ids:
fw_rnns[input_id].append(n)

# fw_rnn and bw_rnn must share the same input
birnn_input = list(set(fw_rnns.keys()).intersection(bw_rnns.keys()))
@@ -554,27 +570,40 @@ def belong_to_birnn(g, fw_rnn, bw_rnn, rnn_type):
return True


def get_reverse_nodes_after_y_output(g, rnn_bw):
def is_tail_slice_op(node):
return (
node.type == 'StridedSlice' and
node.inputs[1].get_tensor_value() == [-1] and
node.inputs[2].get_tensor_value() == [0] and
node.inputs[3].get_tensor_value() == [1] and
node.get_attr('shrink_axis_mask').i == 1
)
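For context, a hedged illustration of what is_tail_slice_op matches: when return_sequences=False, taking the last timestep y[-1] in TensorFlow lowers to a StridedSlice with begin=[-1], end=[0], strides=[1], and shrink_axis_mask=1, which is exactly the pattern checked above.

import numpy as np
import tensorflow as tf

y = np.random.rand(10, 2, 16).astype(np.float32)   # [seq_len, batch, hidden]
last = tf.constant(y)[-1]                           # integer index -> StridedSlice + shrink
same = tf.strided_slice(y, [-1], [0], [1], shrink_axis_mask=1)
assert np.allclose(last.numpy(), same.numpy())      # both equal y[-1], shape (2, 16)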


def get_reverse_or_slice_nodes_after_y_output(g, rnn_bw):
bw_consumers = g.find_output_consumers(rnn_bw.output[0])

# todo: figure out a better way to remove reverse op
squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
s_cnt = len(squeeze_nodes)
if s_cnt == 1:
s = squeeze_nodes[0]
trans_nodes = g.find_output_consumers(s.output[0])
if len(trans_nodes) == 1:
if trans_nodes[0].type == "Transpose":
reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
elif utils.is_tf_reverse_op(trans_nodes[0]):
reverse_nodes = trans_nodes
else:
logger.debug("not found reverse op, unexpected")
return []

are_all_reverse = all([utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes])
if are_all_reverse:
return reverse_nodes
reverse_or_slice_nodes = g.find_output_consumers(s.output[0])
if len(reverse_or_slice_nodes) == 1:
if reverse_or_slice_nodes[0].type == "Transpose":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

are_all_reverse_or_slice = all([
utils.is_tf_reverse_op(r_op) or is_tail_slice_op(r_op)
for r_op in reverse_or_slice_nodes
])
if are_all_reverse_or_slice:
return reverse_or_slice_nodes

logger.debug("bw y output is used followed by reverse node")
return []
@@ -619,13 +648,28 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output

if rnn_output_index == 0:
axis = 1
# remove reverse op for rnn_bw
reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

for r_op in reverse_nodes:
logger.debug("remove reverse op %s", r_op.name)
g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
to_remove.append(r_op.name)
# remove Reverse (return_sequences=True) or tail slice (return_sequences=False) op for rnn_bw
reverse_or_slice_nodes = get_reverse_or_slice_nodes_after_y_output(g, rnn_bw)

for r_op in reverse_or_slice_nodes:
if utils.is_tf_reverse_op(r_op):
logger.debug("remove reverse op %s", r_op.name)
g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
to_remove.append(r_op.name)
elif is_tail_slice_op(r_op):
# in case of return_sequences=False,
# replace output[-1:] with output[0:1]
attr = {"axes": [0], "starts": [0], "ends": [1]}
inputs_map = {"data": r_op.input[0], **attr}
slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
all_nodes.append(g.get_node_by_output(slice_node_bw))

inputs_map = {"data": slice_node_bw, "axes": [0]}
squeeze_node_bw = GraphBuilder(g).make_squeeze(inputs_map)
all_nodes.append(g.get_node_by_output(squeeze_node_bw))

g.replace_all_inputs(r_op.output[0], squeeze_node_bw, ops=all_nodes)
to_remove.append(r_op.name)
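Why timestep 0 replaces the tail (a reading of the change above, under the assumption that the merged ONNX bidirectional output Y aligns both directions to forward time): the standalone backward RNN emits its output in reversed time order, so its final state sits at index -1; after merging, that same state sits at index 0 of Y, hence Slice(starts=[0], ends=[1], axes=[0]) followed by Squeeze(axes=[0]). A rough NumPy sketch with hypothetical shapes:

import numpy as np

bw_y = np.random.rand(10, 2, 16)        # backward output in reversed time order
tail = bw_y[-1]                         # old graph: StridedSlice tail, shape (2, 16)
merged = bw_y[::-1]                     # merged bi-RNN output, forward time order
head = np.squeeze(merged[0:1], axis=0)  # Slice(starts=[0], ends=[1]) + Squeeze(axes=[0])
assert np.allclose(tail, head)          # the tail of bw_y is timestep 0 after merging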
elif rnn_output_index in [1, 2]:
axis = 0
else:
