
[Flax] Add FlaxMBart #12236

Merged
19 commits merged into huggingface:master on Jul 7, 2021

Conversation

@stancld (Contributor) commented Jun 17, 2021

What does this PR do?

This PR adds the Flax implementation of MBart.
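
A quick usage sketch (hypothetical snippet; it assumes the Flax weights for facebook/mbart-large-cc25 are available on the hub, as patil-suraj mentions pushing checkpoints below, otherwise `from_pt=True` is needed):

```python
from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
# decoder_input_ids default to shift_tokens_right(input_ids, ...) internally
outputs = model(**inputs)
logits = outputs.logits
```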

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten @patil-suraj

@patil-suraj (Contributor) left a comment

Thanks a lot for adding this @stancld!

It's looking great overall; I left a few comments. Specifically:

  • the order of layer norm and attention, and
  • could we add as many "Copied from" statements as possible? (See the example below.)
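
For reference, the "Copied from" convention in transformers marks code duplicated from another model so the copies can be kept in sync automatically (checked by utils/check_copies.py). A minimal sketch of such a statement in the Flax MBart file:

```python
import flax.linen as nn


# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->MBart
class FlaxMBartAttention(nn.Module):
    ...
```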

(8 review threads on src/transformers/models/mbart/modeling_flax_mbart.py and tests/test_modeling_flax_mbart.py; all resolved)
stancld added 3 commits June 19, 2021 15:08

* Fix shift_tokens_right method according to the MBart implementation
* Update shift_tokens_right in tests accordingly
* Fix the import issue and update the docs file
* make style quality
* Change the order of the normalization layer and attention
* Add some copy statements
@stancld (Contributor, Author) commented Jun 19, 2021

Thanks a lot for your suggestions, @patil-suraj. I fixed the order of the attention and normalization layers and some other minor bugs, and also added some additional copy statements.

I also changed the shift_tokens_right method, as it appears to differ for MBart models: they don't have a single decoder_start_token_id, in contrast to other Bart-like models. Having no decoder_start_token_id, however, currently leads to some issues within the generate method. (I'll try to have a look at what can be done here.)
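
For reference, a minimal sketch of the idea (the merged implementation in modeling_flax_mbart.py differs in details, e.g. it builds the start tokens with a list comprehension and an explicit shift):

```python
import jax.numpy as jnp


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray:
    """Shift ids one token to the right and wrap the last non-pad token (the
    <LID> language-id token) around to position 0. MBart has no single
    decoder_start_token_id, unlike other Bart-like models."""
    # Replace possible -100 values in labels by pad_token_id.
    tokens = jnp.where(input_ids == -100, pad_token_id, input_ids)
    # Index of the last non-pad token (the language id) in each row.
    index_of_eos = (tokens != pad_token_id).sum(axis=-1, keepdims=True) - 1
    decoder_start_tokens = jnp.take_along_axis(tokens, index_of_eos, axis=-1)
    # Prepend the language id and drop the last column: a right shift.
    return jnp.concatenate([decoder_start_tokens, tokens[:, :-1]], axis=-1)
```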

@stancld changed the title from "[WIP] [Flax] Add FlaxMBart" to "[Flax] Add FlaxMBart" on Jun 20, 2021
@patil-suraj (Contributor) left a comment

Looks good to me! Just left a couple of comments that need to be taken care of before merging.

I will run all slow tests and push the checkpoints to the hub before merging.

(3 review threads on src/transformers/generation_flax_utils.py and src/transformers/models/mbart/modeling_flax_mbart.py; all resolved)
@patil-suraj (Contributor) left a comment

The encoder and decoder require an extra layer_norm at the end (see the sketch below).

(3 review threads on src/transformers/models/mbart/modeling_flax_mbart.py and tests/test_modeling_flax_mbart.py; all resolved)
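
A condensed, hypothetical illustration of the point; the real FlaxMBartEncoder also contains the token/position embeddings and the encoder layer stack:

```python
import flax.linen as nn
import jax.numpy as jnp


class CondensedEncoder(nn.Module):
    """Hypothetical sketch, not the real FlaxMBartEncoder."""

    @nn.compact
    def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        # ... embeddings and the stack of encoder layers would run here ...
        # Unlike Bart, MBart normalizes the final hidden states once more:
        return nn.LayerNorm(epsilon=1e-05)(hidden_states)
```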
@patil-suraj (Contributor)
@stancld I pushed a couple of commits to add the layer_norm in the encoder and decoder. Now all slow tests are passing.
@patrickvonplaten could you please take a final look?
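
(For context: slow tests in transformers are skipped by default and enabled via the RUN_SLOW environment variable, e.g. RUN_SLOW=1 python -m pytest tests/test_modeling_flax_mbart.py.)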

From shift_tokens_right in src/transformers/models/mbart/modeling_flax_mbart.py:

decoder_start_tokens = jnp.array(
    [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
).squeeze()
# for loop does a jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
for i in range(prev_output_tokens.shape[1] - 1, 0, -1):
    prev_output_tokens = prev_output_tokens.at[:, i].set(prev_output_tokens[:, i - 1])
Contributor comment: nice!

@patrickvonplaten (Contributor) left a comment:

Very nice!

@patil-suraj patil-suraj merged commit 61400e1 into huggingface:master Jul 7, 2021