-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Bidirectional #2770
Add Bidirectional #2770
Conversation
dd4fdf5
to
b5d50c5
Compare
c51c9f6
to
1834bdb
Compare
c0314ad
to
da12439
Compare
Codecov Report
@@ Coverage Diff @@
## main #2770 +/- ##
==========================================
+ Coverage 81.91% 82.01% +0.09%
==========================================
Files 55 55
Lines 5918 5977 +59
==========================================
+ Hits 4848 4902 +54
- Misses 1070 1075 +5
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
kernel = jnp.concatenate(kernels, axis=-1) | ||
if use_bias: | ||
biases = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is created to make mypy
happy but also improves the correctness of the program.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!!
@@ -717,6 +744,12 @@ def scan_fn( | |||
else: | |||
carry, outputs = scan_output | |||
|
|||
if reverse and keep_order: | |||
outputs = jax.tree_map( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe factor this out into a separate function since you use it twice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we DRY or WET? For this use case I think abstracting away a tree_map leads to unnecessary indirection.
b7c0a91
to
984ec14
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for coding this up!
flax/linen/recurrent.py
Outdated
@@ -651,16 +664,27 @@ def __call__( | |||
else it will return a tuple of the final carry and the output sequence. | |||
time_major: if ``time_major=False`` (default) it will expect inputs with shape | |||
``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. | |||
reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence | |||
is processed from left to right, else it will be processed from right to left, output order | |||
will be the same as input order. If ``segmentation_mask`` is passed, padding elements will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think keep_order
is still a bit difficult to understand, so good documentation is important. This might be too unclear:
reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence
is processed from left to right, else it will be processed from right to left, output order
will be the same as input order.
what about:
reverse: overrides the reverse
attribute, if reverse=False
(default) the sequence
is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order (but with the padding values at the end).
maybe even a simple example could be good. e.g.
`[1, 2, 3, 0, 0]` vs `[3, 2, 1, 0, 0]`?
984ec14
to
74e00f2
Compare
74e00f2
to
6412c6c
Compare
What does this PR do?
Bidirectional
Module that takes twoRNN
instances and applies the bidirectional logic.reverse
andkeep_order
arguments toRNN
to allow the implementation of the bidirectional logic.Array
fromAny
tojax.Array
and fixes some type errors.