[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models #9559
base: main
Conversation
Pull from head
I think it won't be too difficult to support mllama+flashattention. @sroy745 ping me if you need more background information. I'll go through the code later today.
Thanks for the great work. I left some comments, mainly about simplifying the logic for the different AttentionType cases.
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten()
if updated_slot_mapping is not None else None,
I think we do not need this branch. In the decode phase with attn_type == ENCODER_DECODER, key and value should be None, so we never enter the true branch of if (attn_type != AttentionType.ENCODER) and (key is not None) and (value is not None):
I think we can add a comment to explain this and remove the branch.
Yes, this is not needed, as you mentioned. I had it because I was getting a mypy type-check error. Removed this condition and instead added an ignore annotation.
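(Editorial note: a minimal, self-contained sketch of the pattern discussed here, i.e. a guard on key/value plus a type-ignore instead of the extra None branch. The function and cache layout below are placeholders for illustration, not the actual vLLM cache-write op.)

from typing import Optional

import torch


def _write_prefill_kv(key: Optional[torch.Tensor],
                      value: Optional[torch.Tensor],
                      updated_slot_mapping: Optional[torch.Tensor],
                      kv_cache: torch.Tensor) -> None:
    """Toy stand-in for the backend's cache-write path (not vLLM code)."""
    if key is not None and value is not None:
        # In the decode phase of an encoder-decoder model, key/value are None,
        # so this branch is skipped. When we do get here, updated_slot_mapping
        # is always populated, so the extra "if ... is not None else None"
        # guard is unnecessary; mypy cannot infer that, hence the ignore.
        slots = updated_slot_mapping.flatten()  # type: ignore[union-attr]
        kv_cache[0].view(-1, *key.shape[1:])[slots] = key
        kv_cache[1].view(-1, *value.shape[1:])[slots] = value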
raise AttributeError(f"Invalid attention type {str(attn_type)}")

def _get_num_prefill_encode_decode_tokens( |
If this function is the same as the one in the xformers backend, can we move it to utils.py and call it from both the flash_attn and xformers backends?
Moved it to utils and am using it now in the xformers backend. However, there is a slight difference in the way I set num_encoder_tokens when attention_type = DECODER, as you noted in your other comment.
if (attn_type == AttentionType.ENCODER or \
        attn_type == AttentionType.ENCODER_DECODER):
    key = key[:num_encoder_tokens]
    value = value[:num_encoder_tokens]
else:
    key = key[:num_prefill_tokens]
    value = value[:num_prefill_tokens]
If you make _get_num_prefill_encode_decode_tokens() set num_encoder_tokens = attn_metadata.num_prefill_tokens when attn_type == AttentionType.DECODER, like the xformers backend does, you can avoid this branch (roughly as sketched below). num_encoder_tokens here plays a role similar to q_len.
Also, I am not sure it is correct to remove these lines and pass the full key and value to the attention kernel.
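(Editorial note: to illustrate the suggestion, a minimal sketch of the helper with the DECODER case folded in. The metadata field names and import path are assumptions based on this thread, not the exact vLLM code.)

from vllm.attention import AttentionType  # assumed import path


def _get_num_prefill_encode_decode_tokens(attn_metadata, attn_type):
    """Illustrative sketch only; field names are assumed from the thread."""
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: K/V length matches the prefill query length,
        # so the caller can always slice key/value by num_encoder_tokens.
        num_encoder_tokens = attn_metadata.num_prefill_tokens
    else:
        # ENCODER / ENCODER_DECODER: K/V come from the encoder sequence.
        num_encoder_tokens = attn_metadata.num_encoder_tokens
    return (attn_metadata.num_prefill_tokens, num_encoder_tokens,
            attn_metadata.num_decode_tokens)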
I am slightly in favor of this approach because, when the AttentionType is DECODER, it seems more intuitive to split the key based on prefill_tokens rather than encoder_tokens; encoder_tokens seems more relevant when the AttentionType is ENCODER or ENCODER_DECODER.
I modified the xformers code to use the same split, since I am now using the common get_num_prefill_encode_decode_tokens. Please let me know your preference and I will update both backends accordingly.
I think the goal of _get_num_prefill_encode_decode_tokens is to unify different attention types and avoid branching in the following code path as much as possible. What about renaming the three variables to make it clearer, e.g., num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens?
Done. Renamed them to num_prefill_query_tokens, num_prefill_kv_tokens, and num_decode_query_tokens, and removed the if branches in the backends.
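(Editorial note: a hedged sketch of what the branch-free prefill split can look like once the helper returns the renamed values. The wrapper function and the helper's final name/location are assumed for illustration, not quoted from the PR.)

def split_prefill_and_decode(query, key, value, attn_metadata, attn_type):
    """Illustration only: split tensors without per-attention-type branches."""
    # Shared helper, assumed to live in utils.py after the refactor.
    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = get_num_prefill_encode_decode_tokens(
         attn_metadata, attn_type)
    # Queries are split by query-token counts; keys/values by kv-token counts,
    # which already encode the ENCODER / ENCODER_DECODER vs DECODER difference.
    decode_query = query[num_prefill_query_tokens:]
    prefill_query = query[:num_prefill_query_tokens]
    prefill_key = key[:num_prefill_kv_tokens]
    prefill_value = value[:num_prefill_kv_tokens]
    # num_decode_query_tokens would be used later on the decode path.
    return prefill_query, prefill_key, prefill_value, decode_query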
Thanks for the review. Addressed your comments. PTAL.
Thanks for your fix. I left some comments.
Thanks for the review. Addressed comments. PTAL.
This pull request has merge conflicts that must be resolved before it can be merged.
LGTM. Thanks for your hard work on this. Looking forward to the follow-up PRs for test_encoder_decoder_attention and mllama support.
Also CC @WoosukKwon. You may need to sync this PR to v1 later.
@ywang96 PTAL when you get a chance. The PR has been LG'ed by @heheda12345, is synced to head, and all tests are passing.
This pull request has merge conflicts that must be resolved before it can be merged.
This PR adds support for the flash attention kernel for encoder-decoder models. For encoder-decoder models with dtype=bfloat16, the default backend choice is now FlashAttention instead of XFormers. However, for llama-3.2-11b-vision-instruct we still use the XFormers backend even with dtype=bfloat16, because the model implementation (models/mllama.py) has a dependency on PagedAttention.
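(Editorial note: a rough, toy illustration of the backend choice described above; this is not vLLM's actual backend selection code.)

import torch


def choose_encoder_decoder_backend(dtype: torch.dtype,
                                   model_needs_paged_attention: bool) -> str:
    """Toy illustration of the behavior described above, not vLLM's selector."""
    if dtype == torch.bfloat16 and not model_needs_paged_attention:
        return "FLASH_ATTN"  # new default for encoder-decoder models
    # e.g. llama-3.2-11b-vision-instruct (models/mllama.py) still relies on
    # PagedAttention, so it stays on XFormers for now.
    return "XFORMERS"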
To add this support, we make the following changes in this PR:
#7366