-
Notifications
You must be signed in to change notification settings - Fork 431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix keras bidirectional merge failures #1869
Conversation
Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>
This pull request introduces 1 alert when merging 00562aa into bf4a22d - view on LGTM.com new alerts:
|
support below cases: - there are one or two Identity layers between input/output and RNN - Transpose-Reverse-backward (previously, only Reverse-Transpose-backward was supported) - return_sequences=False with no Reverse after the backward Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>
00562aa
to
ce644d2
Compare
|
||
if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity": | ||
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to do it twice here? Or update it like:
if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, tensorflow makes one or two Identity, it depends on Layer arguments.
I changed the code to nested if.
tf2onnx/rewriter/rnn_utils.py
Outdated
@@ -620,12 +650,27 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output | |||
if rnn_output_index == 0: | |||
axis = 1 | |||
# remove reverse op for rnn_bw |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the comments accordingly to involve tailed slice node.
- Consecutive Identity checks changed to nested if - update comment for remove Reverse or tail-slice op Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contribution!
Fixed an issue where the Bidirectional layers generated by Keras were not correctly combined into a single LSTM/GRU.
currently biltm_rewriter/bigru_rewriter only supports the following structure
but, tf2 keras Bidirectional Layer generates different graph
I also added a test to check that all Bidirectional recurrent layers are combined into one. The first commit will fail the test.