
[Flax] Implement FlaxElectraModel, FlaxElectraForMaskedLM, FlaxElectraForPreTraining #9172

Closed · wants to merge 14 commits

Conversation

chris-tng (Contributor):

What does this PR do?

  1. Implement the Flax version of the Electra model: FlaxElectraModel, FlaxElectraForMaskedLM, FlaxElectraForPreTraining. Most of the code is adapted from the FlaxBert implementation, with changes to the parameters and the forward pass.
  2. Adjust convert_to_pytorch to load weights for Electra.
  3. Implement FlaxElectraGeneratorPredictions and FlaxElectraDiscriminatorPredictions for the generator and discriminator prediction heads.
  4. Implement tests in tests/test_modeling_flax_electra.py.

The forward pass can be verified by running:

pytest tests/test_modeling_flax_electra.py
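
For context, a rough usage sketch of the new classes (not part of this PR) might look as follows, assuming the final API mirrors the existing FlaxBert classes and that Flax weights are available for the referenced checkpoint:

from transformers import ElectraTokenizerFast, FlaxElectraModel

tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
model = FlaxElectraModel.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("Flax Electra forward pass", return_tensors="np")  # NumPy tensors for JAX
outputs = model(**inputs)                                             # forward pass under JAX/Flax
print(outputs.last_hidden_state.shape)                                # (batch_size, seq_len, hidden_size)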

Hi @patrickvonplaten, @mfuntowicz, I've seen your work on FlaxBert, so I'm tagging you in case you want to review. Please note that I use Flax's setup instead of the @nn.compact decorator, since the former

  • allows testing and inspecting submodules
  • separates submodule declaration from the forward pass (the forward-pass method can get very long with @nn.compact); see the short sketch after this list
    I'm happy to revert this change to keep the code style consistent.
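
As an illustration of the trade-off (a toy example, not code from this PR), the same layer written both ways:

import flax.linen as nn

class WithSetup(nn.Module):
    hidden_size: int = 256

    def setup(self):
        # Submodules are declared up front and become attributes,
        # so they can be inspected or tested individually.
        self.dense = nn.Dense(self.hidden_size)

    def __call__(self, hidden_states):
        return nn.gelu(self.dense(hidden_states))

class WithCompact(nn.Module):
    hidden_size: int = 256

    @nn.compact
    def __call__(self, hidden_states):
        # Submodules are defined inline during the forward pass; concise,
        # but there is no self.dense attribute to inspect from the outside.
        return nn.gelu(nn.Dense(self.hidden_size)(hidden_states))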

Let me know if you have any questions or feedback.
Thanks.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests? Yes, tests added in tests/test_modeling_flax_electra.py

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    features = x.shape[-1]
Contributor:

can we replace x and y by hidden_states?

Contributor Author (@chris-tng):

Sure, it's fixed

@patrickvonplaten (Contributor) left a comment:

The PR looks great to me!

I think your argument for using setup instead of nn.compact is a good one, and we should probably replace all usage of nn.compact with setup.

In addition to @chris-tng's arguments for using setup instead of nn.compact (easier to test, and a shorter, more concise forward function), a third argument is that it makes the class structure more similar to PyTorch and TF. Transformers users would probably have an easier time understanding what is happening with setup than with nn.compact.

I'd be in support of replacing all nn.compacts with the setup function.

What do you think @sgugger @mfuntowicz? I don't really see an advantage in using nn.compact over setup.

@sgugger (Collaborator) left a comment:

I'm no Flax expert (yet) but this looks good to me! For the difference between nn.compact and setup I really don't know enough to be able to weigh in.

Comment on lines 102 to 106
Args:
x: the inputs

Returns:
Normalized inputs (the same shape as inputs).
Collaborator:

Please use 4 spaces for indentation :-)

Comment on lines 354 to 367:

class FlaxElectraPooler(nn.Module):
    kernel_init_scale: float = 0.2
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        cls_token = hidden_states[:, 0]
        out = nn.Dense(
            hidden_states.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
Collaborator:

I don't think we have this in the PyTorch/TF versions? And looking at the file it doesn't seem to be used anywhere.

Contributor Author (@chris-tng):

It's removed!

@avital (Contributor) commented Dec 24, 2020:

Hi @chris-tng -- thanks for trying out Flax in HF Transformers!

A quick comment on nn.compact and setup (I work on Flax): indeed, if you want access to submodules, such as for transfer learning, then using setup is the way to go. I skimmed the PR and see that you use setup in some places and nn.compact in others? I'm curious whether you found nn.compact more useful in particular settings.

Indeed, setup is more similar to the PyTorch style (though you still get shape inference if you use nn.compact in modules that don't have submodules). nn.compact is nice if you want to use loops or conditionals to define submodules based on hyperparameters, and some people also prefer how it "co-locates" their submodule definitions and usage. But ultimately it's somewhat a matter of preference.
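
To make the loop/conditional point concrete, here is a small illustrative module (not from Flax's docs or this PR) whose submodules are driven by hyperparameters:

import flax.linen as nn

class MLPStack(nn.Module):
    layer_sizes: tuple = (128, 64, 32)
    use_final_tanh: bool = False

    @nn.compact
    def __call__(self, x):
        for size in self.layer_sizes:      # submodules defined in a loop
            x = nn.relu(nn.Dense(size)(x))
        if self.use_final_tanh:            # submodule defined conditionally
            x = nn.tanh(nn.Dense(1)(x))
        return x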

(Please do share any other thoughts or questions on Flax on our discussion board: https://github.com/google/flax/discussions)

Happy holidays and new year!

@patrickvonplaten (Contributor):

Hey @avital,

Thanks a lot for your input here! That's very useful. Most of the main contributors to Transformers are on holiday at the moment, and this is a rather big design decision for Flax going forward, so I think we'll have to wait until early January, when everybody is back (@sgugger, @LysandreJik, @mfuntowicz).

Happy holiday to you as well :-)

@chris-tng (Contributor Author):

Hi @avital ,

Apologies for my delayed response. I appreciate your great work on Flax. Regarding the use of setup() and nn.compact: personally, I find setup works better for testing submodules, which is useful when converting and debugging a module (and its submodules). For instance, I can create a module with several submodules:

class Dummy(nn.Module):

    def setup(self):
        self.submodule1 = nn.Dense(10)
        self.submodule2 = MyLayerNorm()  # a custom layer-norm module

    def __call__(self, x):
        # forward pass built from the submodules declared in setup()
        return self.submodule2(self.submodule1(x))

After loading model weights from a dict, I can access/debug submodules simply by accessing the attributes dummy.submodule1 and dummy.submodule2. From there, I can debug the forward pass and check the weights of individual submodules.
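
A minimal sketch of that workflow (with nn.LayerNorm standing in for MyLayerNorm, and an arbitrary input shape) could look like:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Dummy(nn.Module):
    def setup(self):
        self.submodule1 = nn.Dense(10)
        self.submodule2 = nn.LayerNorm()

    def __call__(self, x):
        return self.submodule2(self.submodule1(x))

x = jnp.ones((2, 4))
variables = Dummy().init(jax.random.PRNGKey(0), x)   # or weights converted from a PyTorch state dict
bound = Dummy().bind(variables)                       # setup() runs; submodules become attributes
print(bound.submodule1(x).shape)                      # run/debug a single submodule: (2, 10)
print(jax.tree_util.tree_map(jnp.shape, variables))   # inspect individual parameter shapes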

Shameless plug: I wrote a blog post about porting a Hugging Face PyTorch model to Flax, here. I'm a new Flax user, so please correct me if I'm missing anything.

Happy holiday and happy new year to everyone 🎄 🍾

@patrickvonplaten (Contributor):

Hey @chris-tng,

sorry to have kept you waiting this long. I'll resolve the merge conflicts in your PR and then use it to change @nn.compact to setup in all other Flax models as well, so that we have a common standard. Most of our users are used to the "PyTorch" style, and I only see advantages for our library philosophy:

  • We base most design decisions on the PyTorch style
  • We prefer slightly less compact but more readable code over the slightly more "magic" functionality that might reduce code
  • To me there are no real downsides to using setup

@patrickvonplaten (Contributor):

Intermediate state is saved here: #9484. I will push to this PR on Monday at the latest.

@patrickvonplaten (Contributor):

Hey @chris-tng,

I noticed that we will probably have to wait until google/flax#683 is merged to be able to continue with this PR. Will keep you up to date :-)

@avital (Contributor) commented Feb 11, 2021:

Hi folks, sorry for the delay with the new-year shuffle and school shutdown.

google/flax#683 required a bit more conversation and updating some other codebases but now it's merged! If you have a moment, please take a look and see if it helps unblock progress. We'll release Flax 0.4.0 soon, but installing from GitHub now is the way to go.

@github-actions (bot) closed this Mar 6, 2021
@LysandreJik reopened this Mar 6, 2021
@huggingface deleted a comment from the github-actions bot Mar 6, 2021
@LysandreJik added the WIP label and removed the wontfix label Mar 6, 2021
@CoderPat (Contributor):

Hey, sorry for barging in!
I needed a small BERT-like model in JAX, so I recently updated this in a local branch to work with the Flax refactoring that makes checkpoints directly compatible with PyTorch (plus fixed some other issues that had slipped through the cracks).
Should I push directly to this branch or open a new PR from my fork? Also, should I wait for #11364 and update my code accordingly?

@patrickvonplaten (Contributor):

> Hey, sorry for barging in
> I was needing a small BERT-like model in Jax and so I've recently updated this in a local branch to work with the Flax refactoring that makes checkpoints directly compatible with PyTorch (plus fixing some other issues that had gone through the cracks)
> Should I push directly to this branch or make a new PR from my fork? Also should I wait #11364 and update my code accordingly?

Hey @CoderPat,

It would be great if you could wait until #11364 is merged (should be done in the next 2 days). The PR fixes a couple of bugs :-)

@CoderPat (Contributor):

No problem @patrickvonplaten! Also regarding git logistics, is it better to ask @chris-tng for permission to push directly to his branch?

@patrickvonplaten (Contributor):

> No problem @patrickvonplaten! Also regarding git logistics, is it better to ask @chris-tng for permission to push directly to his branch?

I think it's alright to copy-paste the code that is still useful and open a new branch, if you'd like to add Electra :-). On the new branch we should give credit to @chris-tng, but since this PR is quite old now, I think he would be fine if we close it and open a new one (please let me know if that's not the case, @chris-tng :-)). #11364 should be the last refactor before the "fundamental" Flax design is finished.

@CoderPat (Contributor):

Just to confirm, @patrickvonplaten: the Flax refactor is merged and the structure should be stable enough that I can work on implementing Electra, right?

@patrickvonplaten (Contributor):

Exactly @CoderPat - very much looking forward to your PR :-)

@ArthurZucker (Collaborator):

Closing as this PR is super old and partly fixed by #11426

Labels: WIP
7 participants