
Add MultiQueryAttention & GroupedQueryAttention #18402

Closed
awsaf49 opened this issue Sep 18, 2023 · 5 comments
@awsaf49
Contributor

awsaf49 commented Sep 18, 2023

MultiQueryAttention (MQA) [used in the Falcon LLM] and GroupedQueryAttention (GQA) [used in the Llama 2 LLM] are alternatives to MultiHeadAttention (MHA) that are considerably faster. Here's a speed comparison from my naive implementation:

===================================
          TensorFlow - GPU
===================================
Attention                 : 0.004 sec
Multi Head Attention      : 0.035 sec
Multi Query Attention     : 0.018 sec ( 50.17% faster than MHA )
Grouped Query Attention   : 0.030 sec ( 15.02% faster than MHA )

I think it would be nice to have these layers in keras-core.
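To make the idea concrete, here is a minimal NumPy sketch of grouped-query attention (not the proposed Keras layer, and without the learned projections): the query heads are split into groups, and each group shares a single key/value head. Setting `num_kv_heads=1` recovers MQA, and `num_kv_heads == num_query_heads` recovers ordinary MHA.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Grouped-query scaled dot-product attention (projection-free sketch).

    q:    (q_len, num_query_heads, head_dim)
    k, v: (kv_len, num_kv_heads, head_dim), where
          num_query_heads % num_kv_heads == 0.
    Returns: (q_len, num_query_heads, head_dim)
    """
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    group_size = num_query_heads // num_kv_heads

    # Broadcast each K/V head to the query heads in its group.
    k = np.repeat(k, group_size, axis=1)  # (kv_len, num_query_heads, head_dim)
    v = np.repeat(v, group_size, axis=1)

    # Per-head attention scores: (num_query_heads, q_len, kv_len)
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("qhd,khd->hqk", q, k) * scale

    # Numerically stable softmax over the key axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values, back to (q_len, num_query_heads, head_dim).
    return np.einsum("hqk,khd->qhd", weights, v)
```

The speedup comes from shrinking the K/V projections and cache by the group factor, which matters most for autoregressive decoding.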

Reference Papers:

@mattdangerw
Member

Probably easiest to just write GroupedQueryAttention, and consider MultiQueryAttention a special case of it. We can expose MultiQueryAttention as a subclass of GroupedQueryAttention that sets a single init value, num_key_value_heads=1, on the base class. Somewhat similar to our AdamW class with weight_decay.
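A minimal sketch of that subclass relationship (config-only; the real layer would build projection weights, and the actual argument names in Keras may differ from these assumptions):

```python
class GroupedQueryAttention:
    """Hypothetical base layer: stores the head configuration only."""

    def __init__(self, head_dim, num_query_heads, num_key_value_heads):
        if num_query_heads % num_key_value_heads != 0:
            raise ValueError(
                "num_query_heads must be divisible by num_key_value_heads"
            )
        self.head_dim = head_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads


class MultiQueryAttention(GroupedQueryAttention):
    """MQA as the num_key_value_heads=1 special case of GQA,
    mirroring how AdamW fixes weight_decay on top of Adam."""

    def __init__(self, head_dim, num_query_heads):
        super().__init__(head_dim, num_query_heads, num_key_value_heads=1)
```

This keeps a single implementation of the attention math while still giving MQA a discoverable name.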


There is also some discussion in #18423 for more context. I definitely think adding support here makes sense. And it's probably clearer to have this standalone from MultiHeadAttention instead of just throwing more parameters at that (already quite complex) class.

Thanks for filing!

@awsaf49
Contributor Author

awsaf49 commented Sep 19, 2023

We can expose MultiQueryAttention as a subclass of GroupedQueryAttention that sets a single init value, num_key_value_heads=1, on the base class.

I was thinking the same thing.

@awsaf49
Contributor Author

awsaf49 commented Sep 19, 2023

Should I open a PR for this?

@mattdangerw
Member

Sounds good! Thank you!

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@awsaf49 awsaf49 closed this as completed Oct 22, 2023