Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bidirectional #2770

Merged
merged 1 commit into from
Mar 22, 2023
Merged

Add Bidirectional #2770

merged 1 commit into from
Mar 22, 2023

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Jan 4, 2023

What does this PR do?

  • Adds a Bidirectional Module that takes two RNN instances and applies the bidirectional logic.
  • Adds the reverse and keep_order arguments to RNN to allow the implementation of the bidirectional logic.
  • Changes the type of Array from Any to jax.Array and fixes some type errors.

@marcvanzee marcvanzee marked this pull request as draft February 6, 2023 14:24
@cgarciae cgarciae marked this pull request as ready for review February 23, 2023 14:58
@cgarciae cgarciae changed the title [WIP] Add Bidirectional Add Bidirectional Feb 23, 2023
@cgarciae cgarciae force-pushed the add-bidirectional branch from c0314ad to da12439 Compare March 1, 2023 16:10
@codecov-commenter
Copy link

codecov-commenter commented Mar 1, 2023

Codecov Report

Merging #2770 (6412c6c) into main (23340bb) will increase coverage by 0.09%.
The diff coverage is 92.64%.

@@            Coverage Diff             @@
##             main    #2770      +/-   ##
==========================================
+ Coverage   81.91%   82.01%   +0.09%     
==========================================
  Files          55       55              
  Lines        5918     5977      +59     
==========================================
+ Hits         4848     4902      +54     
- Misses       1070     1075       +5     
Impacted Files Coverage Δ
flax/linen/__init__.py 100.00% <ø> (ø)
flax/linen/recurrent.py 97.81% <92.64%> (-1.73%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@cgarciae cgarciae requested review from marcvanzee and bastings March 1, 2023 18:02
kernel = jnp.concatenate(kernels, axis=-1)
if use_bias:
biases = []
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is created to make mypy happy but also improves the correctness of the program.

Copy link
Collaborator

@marcvanzee marcvanzee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!!

@@ -717,6 +744,12 @@ def scan_fn(
else:
carry, outputs = scan_output

if reverse and keep_order:
outputs = jax.tree_map(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe factor this out into a separate function since you use it twice?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we DRY or WET? For this use case I think abstracting away a tree_map leads to unnecessary indirection.

flax/linen/recurrent.py Outdated Show resolved Hide resolved
flax/linen/recurrent.py Show resolved Hide resolved
flax/linen/recurrent.py Show resolved Hide resolved
flax/linen/recurrent.py Outdated Show resolved Hide resolved
@cgarciae cgarciae requested a review from marcvanzee March 7, 2023 15:05
@cgarciae cgarciae force-pushed the add-bidirectional branch 2 times, most recently from b7c0a91 to 984ec14 Compare March 8, 2023 15:55
Copy link
Contributor

@bastings bastings left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for coding this up!

@@ -651,16 +664,27 @@ def __call__(
else it will return a tuple of the final carry and the output sequence.
time_major: if ``time_major=False`` (default) it will expect inputs with shape
``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``.
reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence
is processed from left to right, else it will be processed from right to left, output order
will be the same as input order. If ``segmentation_mask`` is passed, padding elements will
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think keep_order is still a bit difficult to understand, so good documentation is important. This might be too unclear:

reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence
        is processed from left to right, else it will be processed from right to left, output order
        will be the same as input order.

what about:

reverse: overrides the reverse attribute, if reverse=False (default) the sequence
is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order (but with the padding values at the end).


maybe even a simple example could be good. e.g. 
`[1, 2, 3, 0, 0]` vs `[3, 2, 1, 0, 0]`?

@cgarciae cgarciae force-pushed the add-bidirectional branch from 984ec14 to 74e00f2 Compare March 21, 2023 22:12
@cgarciae cgarciae force-pushed the add-bidirectional branch from 74e00f2 to 6412c6c Compare March 22, 2023 00:13
@chiamp chiamp self-requested a review March 22, 2023 04:09
@copybara-service copybara-service bot merged commit a4a89ae into google:main Mar 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants