
Flex attention support with arbitrary 4d mask for LlamaModel #33898

Open
alex-hh opened this issue Oct 2, 2024 · 6 comments
Labels: Feature request, Flash Attention

Comments

@alex-hh

alex-hh commented Oct 2, 2024

Feature request

It would be nice to combine the benefits of flex attention and 4d masking.

Perhaps LlamaModel could be a first case, allowing arbitrary 4d masks to be handled via an efficient flex attention path.

Motivation

Custom attention masking/biasing patterns lead to considerable improvements in flexibility, and are central to state-of-the-art models like AlphaFold and recent multimodal models.

4d attention masking in Transformers already gives the user the flexibility to define custom biases; however, performance is limited by the fact that 4d masking is incompatible with flash attention.

A flex attention path supporting arbitrary 4d masks would retain full flexibility while maintaining performance. As far as I understand, nothing comparable exists in Transformers currently.
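
For concreteness, a rough sketch of what such a path could look like with PyTorch's flex_attention API (torch >= 2.5). This is not existing Transformers code; shapes and tensors below are placeholders, and the call is run eagerly here (in practice it would be compiled for performance):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, Q_LEN, KV_LEN, D = 2, 8, 128, 128, 64

# arbitrary user-supplied 4d boolean mask: True = attend, False = masked out
mask_4d = torch.rand(B, H, Q_LEN, KV_LEN, device=device) > 0.5

def mask_mod(b, h, q_idx, kv_idx):
    # look up the user-provided 4d mask instead of computing e.g. a causal pattern
    return mask_4d[b, h, q_idx, kv_idx]

block_mask = create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device=device)

q = torch.randn(B, H, Q_LEN, D, device=device)
k = torch.randn(B, H, KV_LEN, D, device=device)
v = torch.randn(B, H, KV_LEN, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, Q_LEN, D)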

Your contribution

I'm not very familiar with what this would take.

alex-hh added the Feature request label on Oct 2, 2024
@LysandreJik
Member

Thanks for your request! cc @ArthurZucker

@ArthurZucker
Collaborator

Sounds super good! If a contributor wants to have a go, I'm happy to review; otherwise it's planned for sure! See #32877

@alex-hh
Author

alex-hh commented Oct 4, 2024

@ArthurZucker I might be interested in having a go. I think it might involve creating a separate attention_bias argument (corresponding to the current 4d attention mask and flex attention's score_mod), and allowing a binary/boolean attention_mask to determine mask_mod. Would you support this kind of change?

This way it should be possible to generate appropriate score_mod and mask_mod functions automatically, which might be nicer API-wise than having the user pass these explicitly. I'm not sure, though, and would be interested to know if you have thoughts on this before I attempt anything. For example, I don't know whether it's suboptimal to explicitly instantiate attention masks and then convert them back into flex attention functions.
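
Roughly, the automatic conversion I'm imagining would look something like this (the helper names and the attention_bias argument are placeholders for the proposal, not an existing Transformers API):

def bias_to_score_mod(attention_bias):
    # attention_bias: float tensor of shape (B, H, Q_LEN, KV_LEN), added to the attention logits
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + attention_bias[b, h, q_idx, kv_idx]
    return score_mod

def bool_mask_to_mask_mod(attention_mask):
    # attention_mask: boolean tensor of shape (B, H, Q_LEN, KV_LEN); True = attend
    def mask_mod(b, h, q_idx, kv_idx):
        return attention_mask[b, h, q_idx, kv_idx]
    return mask_mod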

@ArthurZucker
Collaborator

In general we most probably won't allow passing score_mod and mask_mod directly; we might allow defining them on Gemma2Attention, for example, so that they are easy to overwrite.

Automatic creation sounds a bit tricky, and we'd rather have a reference implementation for each model IMO! 🤗

@alex-hh
Author

alex-hh commented Oct 4, 2024

Ok, makes sense. I'll maybe have a think about a reference implementation for Llama.

By automatic creation I was specifically thinking of something like this (it's actually an example in the flex attention post):

def create_bias_mod(bias):
    # bias: (B, H, Q_LEN, KV_LEN) float tensor added to the attention scores
    def bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias[b, h, q_idx, kv_idx]
    return bias_mod

I don't know enough to judge how useful this would be performance-wise.
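
For reference, the way this would plug in (a rough sketch assuming torch >= 2.5; q, k, v and bias are placeholder tensors, and create_bias_mod is the function above):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 128, 64
bias = torch.randn(B, H, S, S)                    # materialized 4d attention bias
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

bias_mod = create_bias_mod(bias)                  # defined in the snippet above
out = flex_attention(q, k, v, score_mod=bias_mod)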

@ArthurZucker
Collaborator

Will link a PR to this issue if I start working on this before then! 🤗
