
[WIP] Flax BLOOM implementation + demo #17761

Closed

Conversation

haileyschoelkopf (Contributor)

What does this PR do?

This PR will add a Flax implementation of BLOOM, and I'd also be happy to help contribute a tutorial / showcase of how to fine-tune BLOOM, as discussed in #17703 :)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case. --> linked above
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings. --> documentation in progress
  • Did you write any new necessary tests? --> will add once code is closer to completion

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten

I believe @sanchit-gandhi and @patil-suraj were also interested in collaborating. Happy to discuss how best to do this.

@haileyschoelkopf (Contributor, Author) commented Jun 17, 2022

A note on the initial status of this PR:

  • This first commit contains much of the code and structure of the modeling_flax_bloom.py file, copied from the GPT-Neo Flax implementation and already edited in many places to better match the PyTorch BLOOM implementation.
  • There are many TODOs I've left in this file that I still need to get to. The code is not yet in a runnable/finished state.

Next steps:

  • Finish implementing all methods, in particular the FlaxBloomAttention __call__ method, until code runs (see other TODOs in file for other things that need tweaking/fixing)
  • Determine how to handle the alibi tensors and the fact that BLOOM has no hardcoded max length
  • Once the code is working, start testing whether the implementation matches the PyTorch one
  • Make sure tensor parallelism is working correctly / accounted for properly (see issue "Abnormal behavior of OPT except OPT-350m" #17653; this still seems to be an open question on how best to deal with it, but bigscience/bloom-350m has TP=1, so it can be used for testing at first without worrying about TP)

Later on:

  • Add unit tests once the code is working at least reasonably well!
  • Make sure all functions are stateless and the code works fine with jit. I'm relatively new to Flax/JAX, so I definitely need to confirm the correctness of the code on this end.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sanchit-gandhi (Contributor) left a comment

Looks like a great start @haileyschoelkopf, and a good set of TODOs for the next steps! Feel free to ping me (or @patrickvonplaten or @patil-suraj) if you encounter any difficulties or want to discuss ideas, very happy to help with the integration here!

src/transformers/models/bloom/modeling_flax_bloom.py:
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)

# TODO: make this one dense layer that is split into 3 on forward named self.query_key_value
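
For illustration, here is a hedged sketch of the fused layer the TODO above describes: a single dense projection named query_key_value whose output is split into query, key and value on the forward pass. The shapes, the 0.02 initializer and the split convention are assumptions for the sketch, not the final BLOOM layout.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class FusedQKVSketch(nn.Module):
    # Sketch only: one fused projection split into q/k/v per head on forward.
    hidden_size: int
    num_heads: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, hidden_states):
        batch, seq_len, _ = hidden_states.shape
        head_dim = self.hidden_size // self.num_heads
        # one dense layer producing query, key and value in a single matmul
        fused_qkv = nn.Dense(
            3 * self.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            name="query_key_value",
        )(hidden_states)
        # split into the three projections, then reshape per attention head
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)
        to_heads = lambda x: x.reshape(batch, seq_len, self.num_heads, head_dim)
        return to_heads(query), to_heads(key), to_heads(value)
```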

self.attn_dropout = nn.Dropout(config.attention_dropout)

# Scaled Softmax TODO: change this to something implemented in jax (maybe implement in __call__ for attn module?)

Contributor:
Yes! We should define FlaxBloomScaledSoftmax as a standalone Flax nn.Module (similar to how it's done in PT) with its own setup and __call__ methods.
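
For illustration, a minimal sketch of what such a standalone module might look like (the additive-mask convention and the float32 upcast are assumptions for the sketch, not the final design):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class FlaxBloomScaledSoftmaxSketch(nn.Module):
    # Sketch only: scale the attention scores, add the mask, softmax in fp32.
    scale: float = 1.0

    def __call__(self, attention_scores, attention_mask=None):
        input_dtype = attention_scores.dtype
        scores = attention_scores.astype(jnp.float32) * self.scale
        if attention_mask is not None:
            # mask is assumed additive: 0 to keep, large negative to drop
            scores = scores + attention_mask
        probs = jax.nn.softmax(scores, axis=-1)
        return probs.astype(input_dtype)
```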

haileyschoelkopf (Author):
This class has an initial implementation now!


self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)

# TODO: should check if this line (n_head) can be removed. if so, can be removed in pytorch impl.

Contributor:
Looks to be the case! Feel free to open a PR to remove from the PT implementation :-)

return outputs

# TODO: does this still require position_ids?
# TODO: gradient checkpointing

Contributor:
For gradient checkpointing details, see: #17399
Happy to lend a hand here! Have gotten it working myself in a Flax Seq2Seq Speech project
https://github.com/sanchit-gandhi/seq2seq-speech/blob/b28d0c25c8fad0f9ffa6707f91f7aba320d44a4b/models/modeling_flax_wav2vec2.py#L504
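
For reference, a minimal sketch of that pattern in Flax, wrapping the block class in nn.remat so activations are recomputed on the backward pass instead of stored (the nn.Dense block here is a stand-in, not the real FlaxBloomBlock):

```python
import flax.linen as nn


class RematStackSketch(nn.Module):
    # Sketch only: toggleable gradient checkpointing over a stack of blocks.
    num_layers: int
    hidden_size: int
    use_gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, hidden_states):
        block_cls = nn.Dense  # placeholder for the transformer block class
        if self.use_gradient_checkpointing:
            block_cls = nn.remat(block_cls)  # recompute activations on backward
        for i in range(self.num_layers):
            hidden_states = block_cls(self.hidden_size, name=f"layer_{i}")(hidden_states)
        return hidden_states
```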


return outputs

# TODO: does this still require position_ids?

Contributor:
For Flax, we have a design philosophy in which we compute the position_ids and losses outside the model, and so need to pass the position_ids to the model!

@haileyschoelkopf (Contributor, Author)

Thanks for the helpful comments, @sanchit-gandhi! I'll do another pass through the code to fix these and add some more things as soon as I have time.

I think the main thing that needs to be discussed soon is how to handle AliBi for position information, since it means there is no specific max length for BLOOM inputs. I'm not yet sure how to account for this, given things like line 176, where the causal mask is created at the max length of the model and then sliced to get the mask for shorter sequences. (One idea I had was selecting a reasonable "starting max length" and then, if the model receives a longer input sequence, extending the causal mask either permanently or just for that forward pass.)
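
For context, here is a small sketch of the existing pattern being described (build the mask once at an assumed "starting max length", then slice it per call); the 2048 value is a placeholder, since BLOOM has no hardcoded maximum:

```python
import flax.linen as nn
import jax.numpy as jnp

STARTING_MAX_LENGTH = 2048  # assumed placeholder, not a BLOOM config value

# precompute a boolean causal mask of shape (1, 1, max_len, max_len)
causal_mask = nn.make_causal_mask(
    jnp.ones((1, STARTING_MAX_LENGTH), dtype="bool"), dtype="bool"
)


def sliced_causal_mask(seq_length):
    # take the top-left corner of the precomputed mask for shorter inputs
    return causal_mask[:, :, :seq_length, :seq_length]


print(sliced_causal_mask(8).shape)  # (1, 1, 8, 8)
```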

@sanchit-gandhi (Contributor)

Would there be any issues in implementing it in the first of the two ways proposed (set to max_length, slicing as required)?

The problem I envision with the latter of the two approaches is that once the function is jit'd, providing a new input length to the model would result in XLA having to recompile. Each time XLA sees a new input shape, it needs to trace out (compile) the function again. So if we provide a new input shape for each forward pass, XLA will recompile every time (very slow)! The performance benefits of jit'ing a function come when we re-use a function that has already been jit'd, meaning we should try and use fixed input shapes where possible.
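
A tiny example of that retracing behaviour:

```python
import jax
import jax.numpy as jnp


@jax.jit
def forward(x):
    # this print only runs while JAX traces the function, i.e. whenever
    # XLA has to compile it for a new input shape
    print("tracing for shape", x.shape)
    return x * 2


forward(jnp.ones((1, 8)))   # traces and compiles
forward(jnp.zeros((1, 8)))  # same shape: reuses the cached compilation
forward(jnp.ones((1, 16)))  # new shape: traces and compiles again
```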

@haileyschoelkopf (Contributor, Author)

Yeah, the recompilation is definitely something to try to avoid!

But the issue is that the bigscience/bloom config doesn't have a seq_length attribute (though bigscience/bloom-1b3 does: 4096), and we want BLOOM to be able to handle sequences as long as a user wants, since AliBi allows generalization to longer sequences. We could choose a reasonable default max_length and then, if the user passes a sequence that's too long, permanently double the size of the causal mask; this would hopefully allow for fewer recompilations.

But I think we should keep open the possibility of using the model on very long sequences without problems. I don't know of any other models in Transformers that use AliBi embeddings yet, so that's a unique benefit of this model.

@sanchit-gandhi (Contributor)

Let's go with that to start; we can iterate and find an optimal solution as we progress. If we get stuck, there's also the option of asking on one of the JAX/Flax forums to see whether the framework authors have any ideas!

You're right, this will be the first JAX/Flax model in Transformers to use AliBi embeddings! Will be very cool having a model with no theoretical max_len!

@patrickvonplaten (Contributor)

Actually, I don't see a big problem with computing the position_ids for the embeddings on the fly if they depend only on the input length of input_ids.
In general, whenever the user passes a different input length of input_ids, the model will have to be recompiled anyway, so I don't see an issue with generating the position_ids and the causal_mask from the input_ids either, no?

@patrickvonplaten (Contributor)

Or am I misunderstanding something here?

@haileyschoelkopf (Contributor, Author)

If just generating the causal mask at every forward pass is acceptable and wouldn't incur a speed penalty, then that should work fine!

And yes, I don't think we need to pass position_ids into the model; we can just compute the alibi embedding within the forward pass (the PyTorch implementation does this).

Sorry for the delay on this; I'll work on it in the next 2 days.
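
For illustration, a hedged sketch of building the ALiBi bias on the fly from the sequence length (slopes follow the ALiBi paper and assume the number of heads is a power of two; this is not the final BLOOM code):

```python
import jax.numpy as jnp


def build_alibi_bias(num_heads, seq_length):
    # per-head slopes from the ALiBi paper (power-of-two head counts)
    start = 2 ** (-8.0 / num_heads)
    slopes = jnp.array([start ** (i + 1) for i in range(num_heads)])
    # linear bias over key positions; since softmax is shift-invariant per
    # query row, this matches a relative-distance penalty once the causal
    # mask is applied
    bias = slopes[:, None, None] * jnp.arange(seq_length)[None, None, :]
    return bias[None]  # shape: (1, num_heads, 1, seq_length)


print(build_alibi_bias(num_heads=4, seq_length=6).shape)  # (1, 4, 1, 6)
```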

@patrickvonplaten (Contributor)

> If just generating the causal mask at every forward pass is acceptable and wouldn't incur a speed penalty, then that should work fine!
>
> And yes, I don't think we need to pass position_ids into the model; we can just compute the alibi embedding within the forward pass (the PyTorch implementation does this).
>
> Sorry for the delay on this; I'll work on it in the next 2 days.

Great! Yeah, I just talked to @sanchit-gandhi offline. I think what we want to do here is only recompile when the model has to be recompiled anyway, which translates into doing the following:

Allow position_ids to be passed but default them to None. If None, they will be computed on the fly from the shape of input_ids and the values of attention_mask (the same holds for the causal_mask). Let me know if this doesn't make sense, @haileyschoelkopf, or if you have any other questions; more than happy to help :-)
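
For illustration, a small sketch of that default-to-None behaviour (the function and argument names here are illustrative, not the model's actual signature):

```python
import jax.numpy as jnp


def prepare_inputs(input_ids, attention_mask=None, position_ids=None):
    # Sketch only: accept position_ids, but if None build them on the fly
    # from the shape of input_ids, so recompilation only happens when the
    # input shape changes anyway.
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, seq_length), dtype="i4")
    if position_ids is None:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )
    return attention_mask, position_ids


mask, positions = prepare_inputs(jnp.array([[5, 6, 7, 8]]))
print(positions)  # [[0 1 2 3]]
```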

@sanchit-gandhi (Contributor)

Hey @haileyschoelkopf! This looks good with regards to the fused key-query-value matmuls in faddb8d! Just as a heads-up, for gradient checkpointing, you can follow the PR at #17843. Feel free to reach out if there's anything you wish to discuss, very happy to help with any questions!

@younesbelkada mentioned this pull request, Jul 5, 2022
@haileyschoelkopf (Contributor, Author) commented Jul 5, 2022

Added gradient checkpointing, thanks for the pointer @sanchit-gandhi!

Sorry that I haven't been able to push this PR forward faster; I ended up being busier over the past few weeks than expected. EDIT: saw the other PR. @younesbelkada, FYI, there is gradient checkpointing code on this PR now if you need it.

@huggingface deleted a comment from github-actions bot, Jul 29, 2022
@sanchit-gandhi (Contributor)

Thank you @haileyschoelkopf for jumping on this so quickly and getting the structure for the model in place! This PR was completed in #18022.

Let me know if there's anything else you'd like to have a go at adding in JAX/Flax! Or if you'd like to have a go at porting another model to JAX/Flax, I can make some suggestions!

@haileyschoelkopf (Contributor, Author)

Thanks so much for all the helpful comments on this PR, @sanchit-gandhi, and apologies that I wasn't able to iterate on it more quickly!

If I have more time to add another JAX model I'll ping you for sure :)

@patrickvonplaten (Contributor)

Very sorry that we rushed this PR so much @haileyschoelkopf! Very much looking forward to other PRs if you'd like :-)

@haileyschoelkopf (Contributor, Author)

Of course, will ping you if so :)
