[Flax] Implement FlaxElectraModel, FlaxElectraForMaskedLM, FlaxElectraForPreTraining #9172
Conversation
…n model output, ranging from 0.0010 to 0.0016
Returns:
  Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
Can we replace x and y with hidden_states?
Sure, it's fixed
The PR looks great to me!

I think your argument for using setup instead of nn.compact is a good one, and we should probably replace all usage of nn.compact with setup.

In addition to @chris-tng's arguments for using setup instead of nn.compact (easier to test, and it makes for a shorter, more concise forward function), a third argument is that it makes the class signature more similar to PyTorch and TF. Transformers users would probably have an easier time understanding what is happening when setup is implemented vs. nn.compact.

I'd be in support of replacing all nn.compact usages with the setup function.

What do you think @sgugger @mfuntowicz? I don't really see an advantage in using nn.compact over setup.
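To make the comparison concrete, here is a minimal sketch of the two styles under discussion. It is not code from the PR; the FeedForward modules and their names are illustrative, and it assumes flax.linen as the module system:

import flax.linen as nn

class FeedForwardCompact(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, hidden_states):
        # Submodules are declared inline, at the point of use.
        hidden_states = nn.Dense(self.hidden_dim, name="dense")(hidden_states)
        return nn.gelu(hidden_states)

class FeedForwardSetup(nn.Module):
    hidden_dim: int

    def setup(self):
        # Submodules are declared up front, much like PyTorch's __init__;
        # the attribute name ("dense") becomes the parameter scope name.
        self.dense = nn.Dense(self.hidden_dim)

    def __call__(self, hidden_states):
        # The forward pass stays short, and self.dense can be inspected in tests.
        return nn.gelu(self.dense(hidden_states))

Both variants define the same parameters; the setup version simply mirrors the PyTorch/TF pattern of declaring submodules before the forward pass.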
I'm no Flax expert (yet) but this looks good to me! For the difference between nn.compact and setup, I really don't know enough to be able to weigh in.
Args:
  x: the inputs

Returns:
  Normalized inputs (the same shape as inputs).
Please use 4 spaces for indentation :-)
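To make the request concrete, here is the docstring above re-indented with 4 spaces; the enclosing __call__ signature and the one-line summary are illustrative, not taken from the PR:

def __call__(self, x):
    """Applies layer normalization to the inputs.

    Args:
        x: the inputs

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    ...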
class FlaxElectraPooler(nn.Module):
    kernel_init_scale: float = 0.2
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        cls_token = hidden_states[:, 0]
        out = nn.Dense(
            hidden_states.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
I don't think we have this in the PyTorch/TF versions? And looking at the file it doesn't seem to be used anywhere.
It's removed!
Hi @chris-tng -- thanks for trying out Flax in HF Transformers! A quick comment: please do let us know any other thoughts or questions on Flax on our discussion board: https://github.com/google/flax/discussions. Happy holidays and happy new year!
Hey @avital, thanks a lot for your input here! That's very useful. Most of the main contributors to Transformers are on holiday at the moment, and this is a rather big design decision to make going forward with Flax, so I think we'll have to wait until early January, when everybody is back (@sgugger, @LysandreJik, @mfuntowicz). Happy holidays to you as well :-)
Hi @avital, apologies for my delayed response. I appreciate your great work on Flax. Regarding the use of setup, here is roughly what I do:

class Dummy(nn.Module):
    def setup(self):
        self.submodule1 = nn.Dense(10)
        self.submodule2 = MyLayerNorm()

    def __call__(self):
        # do something here

After loading model weights from a dict, I can access/debug a submodule by simply accessing the attribute.

Shameless plug: I wrote a blog post about porting a Hugging Face PyTorch model to Flax, here. I'm a new Flax user, so please correct me if I'm missing anything.

Happy holidays and happy new year to everyone 🎄 🍾
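To make the attribute-access point above concrete, here is a minimal runnable sketch (not from the PR; the Dummy module is trimmed to a single Dense submodule and the input shape is arbitrary). It assumes a recent Flax release where Module.bind is available:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Dummy(nn.Module):
    def setup(self):
        self.submodule1 = nn.Dense(10)

    def __call__(self, x):
        return self.submodule1(x)

x = jnp.ones((1, 4))
variables = Dummy().init(jax.random.PRNGKey(0), x)  # or a dict converted from PyTorch weights

# Bind the loaded variables, then poke at a submodule directly for debugging.
bound = Dummy().bind(variables)
print(bound.submodule1(x).shape)  # (1, 10)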
Hey @chris-tng, sorry to have kept you waiting for this long. I'll solve the merge conflicts in your PR and then use your PR to change the design.
Intermediate state is saved here: #9484. I will push to this PR on Monday at the latest.
Hey @chris-tng, I noticed that we will probably have to wait a bit for google/flax#683 to be merged before we can continue the PR. Will keep you up to date :-)
Hi folks, sorry for the delay with the new-year shuffle and school shutdown. google/flax#683 required a bit more conversation and updating some other codebases, but now it's merged! If you have a moment, please take a look and see if it helps unblock progress. We'll release Flax 0.4.0 soon, but installing from GitHub is the way to go for now.
Hey, sorry for barging in.
Hey @CoderPat, it would be great if you could wait until #11364 is merged (should be done in the next 2 days). The PR fixes a couple of bugs :-)
No problem @patrickvonplaten! Also, regarding git logistics, is it better to ask @chris-tng for permission to push directly to his branch?
I think it's alright to copy-paste the code that is still useful and open a new branch if you'd like to add Electra :-). On the new branch we should give credit to @chris-tng, but since the PR is quite old now I think he would be fine if we close this one and open a new one (please let me know if this is not the case, @chris-tng :-)). #11364 should be the last refactor before the "fundamental" Flax design is finished.
Just to confirm, @patrickvonplaten: the Flax refactor is merged and the structure should be stable enough that I can work on implementing Electra, right?
Exactly @CoderPat - very much looking forward to your PR :-)
Closing as this PR is super old and partly fixed by #11426.
What does this PR do?
- Implement FlaxElectraModel, FlaxElectraForMaskedLM, and FlaxElectraForPreTraining. Most of the code is taken from the FlaxBert version, with changes in parameters and the forward pass.
- Use convert_to_pytorch to load weights for Electra.
- Add FlaxElectraGeneratorPredictions and FlaxElectraDiscriminatorPredictions for the generator and discriminator prediction heads.
- Add tests in tests/test_modeling_flax_electra.py; the forward pass works when running them.
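For context, here is a usage sketch of what the PR aims to enable. This is not from the PR; it assumes the present-day transformers API, where from_pretrained(..., from_pt=True) converts a PyTorch checkpoint to Flax on the fly, and it uses the public google/electra-small-discriminator checkpoint as an example:

from transformers import ElectraTokenizerFast, FlaxElectraModel

tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
# from_pt=True loads the PyTorch checkpoint and converts the weights to Flax.
model = FlaxElectraModel.from_pretrained("google/electra-small-discriminator", from_pt=True)

inputs = tokenizer("Flax Electra says hello", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)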
Hi @patrickvonplaten, @mfuntowicz, I've seen your work on FlaxBert, so I'm tagging you in case you want to review. Please note that I use the Flax setup method instead of the @nn.compact decorator, since the former is easier to test and makes for a shorter, more concise forward function. If you prefer @nn.compact, I'm happy to revert this change to make the code style consistent.
Let me know if you have any questions or feedback.
Thanks.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests? New tests are in tests/test_modeling_flax_electra.py.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.