
Fix keras bidirectional merge failures #1869

Merged
merged 4 commits into onnx:main on Apr 2, 2022

Conversation

kota-row
Contributor

Fixed an issue where the Bidirectional layers generated by Keras were not correctly merged into a single bidirectional LSTM/GRU.

Currently, bilstm_rewriter/bigru_rewriter only supports the following structure:

        /-------------(optional: Transpose) -- forward RNN -- Squeeze -- (optional: Transpose)
shared input
        \-- Reverse---(optional: Transpose) -- backward RNN -- Squeeze -- (optional: Transpose) -- Reverse

However, the tf2 Keras Bidirectional layer generates a different graph (a repro sketch follows this list):

  • There are several Identity ops in between.
  • shared_input -> Transpose -> Reverse -> backward RNN (unlike the figure above, Transpose and Reverse are swapped).
  • With return_sequences=False, there is no Reverse after the Squeeze; there is a StridedSlice instead.
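
For illustration, a minimal sketch of a model that produces this pattern (the shapes and unit counts here are arbitrary assumptions, not taken from the PR):

```python
import tensorflow as tf
import tf2onnx

# Minimal sketch (hypothetical shapes/units): a tf2 Keras Bidirectional
# LSTM with return_sequences=False, which yields the StridedSlice-tailed
# backward branch described above.
model = tf.keras.Sequential([
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(32, return_sequences=False),
        input_shape=(16, 8),
    ),
])

# tf2onnx's Keras entry point; the bilstm/bigru rewriters run during
# this conversion.
model_proto, _ = tf2onnx.convert.from_keras(model)
```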

I also added a test that checks that all Bidirectional recurrent layers are combined into one; the test fails on the first commit alone.
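
A sketch of what such a check can look like on the converted proto (illustrative only, not the PR's actual test; the file name is a placeholder):

```python
import onnx

# Illustrative check (not the PR's actual test): a merged Bidirectional
# layer should appear as one LSTM node with direction="bidirectional",
# not as two unidirectional LSTM nodes. "bilstm_model.onnx" is a
# placeholder path.
model_proto = onnx.load("bilstm_model.onnx")
lstm_nodes = [n for n in model_proto.graph.node if n.op_type == "LSTM"]
assert len(lstm_nodes) == 1
direction = [a for a in lstm_nodes[0].attribute if a.name == "direction"]
assert direction and direction[0].s == b"bidirectional"
```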

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>

lgtm-com bot commented Feb 26, 2022

This pull request introduces 1 alert when merging 00562aa into bf4a22d - view on LGTM.com

new alerts:

  • 1 for Variable defined multiple times

Support the cases below:
- there are one or two Identity layers between input/output and the RNN
- Transpose-Reverse-backward (previously, only Reverse-Transpose-backward was supported)
- return_sequences=False with no Reverse after the backward RNN

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>
@kota-row kota-row marked this pull request as ready for review February 26, 2022 13:35
@hwangdeyu hwangdeyu requested review from hwangdeyu and fatcat-z March 8, 2022 09:01

        if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
            reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

Collaborator

Do we need to do it twice here? Or update it like:

        if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
            reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
            if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])

Contributor Author

Yes, TensorFlow emits one or two Identity ops, depending on the layer arguments. I changed the code to a nested if.
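
For comparison, an equivalent bounded-loop form of the same skip logic (just a sketch; `g` and `reverse_or_slice_nodes` as in the diff above):

```python
# Sketch: equivalent to the nested-if form adopted in the PR. Skip up to
# two consecutive Identity nodes; the bound of 2 reflects the observation
# that TF emits one or two Identity ops here depending on layer arguments.
for _ in range(2):
    if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
        reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
    else:
        break
```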

@@ -620,12 +650,27 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output
    if rnn_output_index == 0:
        axis = 1
        # remove reverse op for rnn_bw
Collaborator

Please update the comment accordingly to also cover the tailing slice node.

- changed consecutive Identity checks to a nested if
- updated the comment about removing the Reverse or tail-slice op

Signed-off-by: Kotaro Yamamoto <kota.crk@gmail.com>
Collaborator

@fatcat-z fatcat-z left a comment

Thanks for your contribution!

@fatcat-z fatcat-z merged commit de7e5af into onnx:main Apr 2, 2022