[WIP] Flax BLOOM implementation + demo #17761
Conversation
A note on the initial status of this PR:
Next steps:
Later on:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Looks like a great start @haileyschoelkopf, and a good set of TODOs for the next steps! Feel free to ping me (or @patrickvonplaten or @patil-suraj) if you encounter any difficulties or want to discuss ideas, very happy to help with the integration here!
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)

# TODO: make this one dense layer that is split into 3 on forward named self.query_key_value
For an example of how this can be done, see: https://github.com/sanchit-gandhi/seq2seq-speech/blob/6de3d8047d568aae8c4ad307d897e4e5c614ae1e/models/modeling_flax_bart.py#L86
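For reference, here is a minimal sketch of the fused query_key_value pattern being suggested: one dense layer computes Q, K and V in a single matmul, and the output is split into three on the forward pass. Class name and defaults are illustrative, not the PR's final code.

```python
import flax.linen as nn
import jax.numpy as jnp

class FusedQKVAttentionSketch(nn.Module):
    hidden_size: int = 64
    num_heads: int = 4

    def setup(self):
        # a single dense layer producing query, key and value in one matmul
        self.query_key_value = nn.Dense(3 * self.hidden_size)

    def __call__(self, hidden_states):
        head_dim = self.hidden_size // self.num_heads
        fused = self.query_key_value(hidden_states)
        # reshape to (..., seq, num_heads, 3 * head_dim), then split the last axis in three
        fused = fused.reshape(fused.shape[:-1] + (self.num_heads, 3 * head_dim))
        query, key, value = jnp.split(fused, 3, axis=-1)
        return query, key, value
```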
self.attn_dropout = nn.Dropout(config.attention_dropout)

# Scaled Softmax TODO: change this to something implemented in jax (maybe implement in __call__ for attn module?)
Yes! We should define FlaxBloomScaledSoftmax as a standalone Flax nn.Module (similar to how it's done in PT) with its own setup and __call__ methods.
This class has an initial implementation now!
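For context, a hedged sketch of what such a standalone module could look like, mirroring the PyTorch BloomScaledSoftmax (scale the scores, apply an additive mask, then softmax). The class name, arguments and defaults here are assumptions, not necessarily the PR's final API.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class FlaxBloomScaledSoftmaxSketch(nn.Module):
    scale: float = 1.0
    softmax_in_fp32: bool = True

    def __call__(self, attention_scores, attention_mask=None):
        input_dtype = attention_scores.dtype
        if self.softmax_in_fp32:
            # upcast for numerical stability, as the PT version does
            attention_scores = attention_scores.astype(jnp.float32)
        attention_scores = attention_scores * self.scale
        if attention_mask is not None:
            # additive mask: 0 where attending is allowed, large negative otherwise
            attention_scores = attention_scores + attention_mask
        return jax.nn.softmax(attention_scores, axis=-1).astype(input_dtype)

# used from an attention module's setup/__call__, e.g.
#   probs = self.scaled_softmax(scores, mask)
```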
self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)

# TODO: should check if this line (n_head) can be removed. if so, can be removed in pytorch impl.
Looks to be the case! Feel free to open a PR to remove from the PT implementation :-)
return outputs

# TODO: does this still require position_ids?
# TODO: gradient checkpointing
For gradient checkpointing details, see: #17399
Happy to lend a hand here! Have gotten it working myself in a Flax Seq2Seq Speech project: https://github.com/sanchit-gandhi/seq2seq-speech/blob/b28d0c25c8fad0f9ffa6707f91f7aba320d44a4b/models/modeling_flax_wav2vec2.py#L504
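As a rough sketch of the flax.linen.remat pattern used in that example (module names and shapes here are placeholders, not the final BLOOM code):

```python
import flax.linen as nn

class BlockSketch(nn.Module):
    # stand-in for a transformer block such as FlaxBloomBlock, illustrative only
    hidden_size: int = 64

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        return nn.Dense(self.hidden_size)(hidden_states)

class LayerCollectionSketch(nn.Module):
    num_layers: int = 2
    gradient_checkpointing: bool = False

    def setup(self):
        block_cls = BlockSketch
        if self.gradient_checkpointing:
            # nn.remat recomputes the block's activations in the backward pass
            # instead of storing them, trading compute for memory; the bool
            # `deterministic` arg (position 2) must be marked static
            block_cls = nn.remat(BlockSketch, static_argnums=(2,))
        self.layers = [block_cls(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, deterministic)
        return hidden_states
```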
return outputs

# TODO: does this still require position_ids?
For Flax, we have a design philosophy in which we compute the position_ids and losses outside the model, and so need to pass the position_ids to the model!
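For illustration, this is the usual pattern for building position_ids outside a Flax model (placeholder shapes; the model call in the comment is hypothetical):

```python
import jax.numpy as jnp

batch_size, seq_len = 2, 16
position_ids = jnp.broadcast_to(jnp.arange(seq_len)[None, :], (batch_size, seq_len))
# the caller then passes them in explicitly, e.g.
#   outputs = model(input_ids, attention_mask, position_ids=position_ids)
```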
Thanks for the helpful comments, @sanchit-gandhi! I'll do another revision through the code fixing these and adding some more things as soon as I have time. I think the main thing that needs to be discussed soon is how to handle ALiBi for position information, since it means there is no specific max length for BLOOM inputs. I'm not too sure yet how to account for this given things like line 176, where the causal mask is made at the max length of the model and then sliced to get the mask for shorter sequences. (One idea I had was selecting a reasonable "starting max length", then if the model gets a longer input sequence the causal mask is extended, either permanently or just for that forward pass.)
Would there be any issues in implementing it in the first of the two ways proposed (setting a reasonable starting max length)?

The problem I envision with the latter of the two approaches is that once the function is jit'd, providing a new input length to the model would result in XLA having to recompile. Each time XLA sees a new input shape, it needs to trace out (compile) the function again. So if we provide a new input shape for each forward pass, XLA will recompile every time (very slow)! The performance benefits of jit'ing a function come when we re-use a function that has already been jit'd, meaning we should try and use fixed input shapes where possible.
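A small self-contained toy (not BLOOM code) demonstrating the retracing behaviour described above: jax.jit compiles once per input shape, so each new sequence length triggers a fresh compilation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
    print("tracing for shape", x.shape)  # printed only when XLA (re)traces
    return x * 2

forward(jnp.ones((1, 8)))   # traces + compiles for shape (1, 8)
forward(jnp.ones((1, 8)))   # cache hit: no retrace
forward(jnp.ones((1, 16)))  # new shape: traces + compiles again
```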
Yeah, the recompilation is definitely something to try to avoid! But the issue is that the bigscience/bloom config doesn't have any seq_length attribute (bigscience/bloom-1b3 does: 4096), and we want BLOOM to be able to handle sequences as long as a user wants, since ALiBi allows generalization to longer sequences. We could maybe just choose a reasonable default seq_length, but I think we should keep the possibility open to use the model on very long sequences without problems. I don't know if any other models in Transformers use ALiBi embeddings yet, so that's a unique benefit of this model.
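For reference, a hedged sketch of the ALiBi bias computed from the actual sequence length, so no maximum length is baked in. The slope formula follows the ALiBi paper and the PyTorch BLOOM code for power-of-two head counts; the function name and shapes are illustrative.

```python
import jax.numpy as jnp

def build_alibi_bias(num_heads: int, seq_len: int) -> jnp.ndarray:
    # per-head slopes form a geometric sequence: 2^(-8/n), 2^(-16/n), ...
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    # linear bias over key positions, shape (num_heads, 1, seq_len);
    # under a causal mask this is equivalent to a distance penalty, since
    # softmax is invariant to a per-row constant shift
    return slopes[:, None, None] * jnp.arange(seq_len)[None, None, :]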
Let's go with that to start - we can iterate and find an optimal solution as we progress. There's also the option of asking on one of the JAX/Flax forums to see if the framework authors have any ideas if we're stuck! You're right, this will be the first JAX/Flax model in Transformers to use ALiBi embeddings! Will be very cool having a model with no theoretical max input length!
Actually, I don't see a big problem with computing the causal mask on the fly in every forward pass. Or am I misunderstanding something here?
If just generating the causal mask at every forward pass is acceptable and wouldn't incur a speed penalty, then that should work fine! And yes, I don't think we need to pass position_ids into the model; we can just compute the alibi embedding within the forward pass (the PyTorch implementation does this). Sorry for the delay on this - I'll work on it in the next 2 days.
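A minimal sketch of building the causal mask inside the forward pass from the actual input length, rather than slicing a precomputed max-length mask (function name is illustrative):

```python
import jax.numpy as jnp
from flax.linen.attention import make_causal_mask

def causal_mask_for(input_ids: jnp.ndarray) -> jnp.ndarray:
    # returns a (batch, 1, seq_len, seq_len) lower-triangular mask derived
    # from the input's own length, so no fixed max length is needed
    return make_causal_mask(input_ids, dtype=bool)
```

Since shapes are static under jit, this only forces a retrace when the sequence length actually changes, which is exactly when XLA would have to recompile anyway.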
Great! Yeah, I just talked to @sanchit-gandhi offline - I think what we want to do here is to only recompile when the model has to be recompiled anyway, which translates into doing the following: allow the causal mask to be created dynamically from the input's sequence length, as discussed above.
Hey @haileyschoelkopf! This looks good with regards to the fused key-query-value matmuls in faddb8d! Just as a heads-up, for gradient checkpointing you can follow the PR at #17843. Feel free to reach out if there's anything you wish to discuss, very happy to help with any questions!
Added gradient checkpointing, thanks for the pointer @sanchit-gandhi! Sorry that I haven't been able to push things forward on this PR faster; I ended up being busier the past few weeks than expected... EDIT: saw the other PR. @younesbelkada, FYI, there is gradient checkpointing code on this PR now if you need it.
Thank you @haileyschoelkopf for jumping on this so quickly and getting the structure for the model in place! This PR was completed in #18022. Let me know if there's anything else you'd like to have a go at adding in JAX/Flax! Or if you'd like to have a go at porting another model to JAX/Flax, I can make some suggestions!
Thanks so much for all the helpful comments on this PR, @sanchit-gandhi, and apologies I wasn't able to iterate quicker on it! If I have more time to add another JAX model I'll ping you for sure :)
Very sorry that we rushed this PR so much @haileyschoelkopf! Very much looking forward to other PRs if you'd like :-)
Of course, will ping you if so :)
What does this PR do?
This PR will add a Flax implementation of BLOOM, and I'd also be happy to help contribute a tutorial / showcase of how to fine-tune BLOOM, as discussed in #17703 :)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. --> linked above
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. --> documentation in progress
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten, @sanchit-gandhi and @patil-suraj I believe were interested in collaborating. Happy to discuss how best to do this.