[Flax] Add FlaxMBart #12236
Conversation
Thanks a lot for adding this @stancld!
It's looking great overall; I left a few comments. Specifically:
- the order of the layer norm and attention layers, and
- could we add as many `# Copied from` statements as possible?
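For context on the first point: MBart uses pre-layer-norm transformer blocks (norm applied before attention), whereas Bart-style post-LN normalizes after the residual add. A minimal NumPy sketch of the two orderings, with a stand-in `attention` callable (hypothetical, for illustration only — not the actual Flax modules):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize over the feature (last) axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def post_ln_block(x, attention):
    # Bart ordering: attention first, layer norm applied to the residual sum
    return layer_norm(x + attention(x))

def pre_ln_block(x, attention):
    # MBart ordering: layer norm first, residual added after attention
    return x + attention(layer_norm(x))
```

With a zero attention stand-in the difference is easy to see: the post-LN block still normalizes its input, while the pre-LN block returns it unchanged.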
* Fix `shift_tokens_right` method according to the MBart implementation
* Update `shift_tokens_right` in tests accordingly
* Fix the import issue and update the docs file
* make style quality
* Change the order of the normalization layer and attention
* Add some copy statements
@patil-suraj Thank you a lot for your suggestions. I fixed the order of the attention and normalization layers and some other minor bugs, and added some additional copy statements. I also changed the
Looks good to me! Just left a couple of comments that need to be taken care of before merging.
I will run all slow tests and push the checkpoints to the hub before merging.
Besides, add `lang_code_to_id` to `MBartTokenizerFast`.
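`lang_code_to_id` is simply a mapping from language codes to the token ids appended after the regular vocabulary. A hypothetical sketch of the attribute (the codes and offset below are illustrative; real values come from the MBart vocab):

```python
# Illustrative only: a short subset of codes and a made-up vocab offset
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "de_DE", "en_XX", "ro_RO"]
BASE_VOCAB_SIZE = 250001  # hypothetical offset where language codes start

lang_code_to_id = {
    code: BASE_VOCAB_SIZE + i for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
```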
The encoder and decoder require an extra `layer_norm` at the end.
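Why: in a pre-LN stack each block normalizes its own input, so the residual stream coming out of the last block is never normalized; the final `layer_norm` fixes that. A minimal sketch, assuming hypothetical `blocks` callables rather than the actual Flax modules:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize over the feature (last) axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def run_stack(x, blocks):
    # pre-LN blocks: normalize, transform, add residual
    for block in blocks:
        x = x + block(layer_norm(x))
    # the extra final layer_norm this comment asks for
    return layer_norm(x)
```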
@stancld I pushed a couple of commits to add the
```python
        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
    ).squeeze()
    # for loop basically does a jax-compatible version of
    # prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    for i in range(prev_output_tokens.shape[1], 0, -1):
```
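The excerpt above works around JAX's immutable arrays. As a minimal NumPy sketch of the same MBart-style shift — the last non-pad token (the language code) becomes the first decoder input and everything else moves one step right — with illustrative token ids, not values from the real vocab:

```python
import numpy as np

def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    # MBart-style shift: move the last non-pad token of each row to
    # position 0 and shift the rest one step to the right.
    prev_output_tokens = input_ids.copy()
    # index of the last non-pad token in each row
    index_of_eos = (input_ids != pad_token_id).sum(axis=1) - 1
    decoder_start_tokens = input_ids[np.arange(input_ids.shape[0]), index_of_eos]
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens
```

In the JAX version, the in-place slice assignments would instead use functional `.at[...].set(...)` updates, which is what the loop in the diff is emulating.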
nice!
Very nice!
What does this PR do?
This PR adds the Flax implementation of MBart.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@patrickvonplaten @patil-suraj