[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models #9559

Open
Wants to merge 104 commits into main

Conversation

sroy745 (Contributor) commented Oct 21, 2024

This PR adds support for the FlashAttention kernel for encoder-decoder models. For encoder-decoder models with dtype=bfloat16, the default backend choice is now FlashAttention instead of XFormers. However, for Llama-3.2-11B-Vision-Instruct we still use the XFormers backend even with dtype=bfloat16, because the model implementation (models/mllama.py) has a dependency on PagedAttention.

To add this support, this PR makes the following changes:

  1. Updated flash_attn.py to add support for encoder-decoder models. Also updated the tests in tests/kernels/test_encoder_decoder.py to exercise the FlashAttention backend along with the existing XFormers backend.
  2. Updated test_bart.py, test_florence2.py and encoder_decoder/test_e2e_correctness.py to run with both backends.
  3. Moved some methods from xformers.py to backend/utils.py so that they can be reused in both xformers.py and flash_attn.py.
  4. Updated the checks in worker/enc_dec_model_runner.py to accept either the FlashAttention or the XFormers backend, instead of only XFormers as before.
  5. Updated models/bart.py to invoke attention.forward with a query of shape [num_tokens, hidden_size]. Previously it invoked forward with a query of shape [num_tokens, num_heads, head_size], which is not the default.
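
For illustration, the backend choice described above can be summarized by the following minimal sketch; the function name, structure, and string backend names are assumptions for this example, not the actual selector code in vLLM:

```python
import torch


def choose_enc_dec_backend(dtype: torch.dtype, is_mllama: bool) -> str:
    """Illustrative only: backend choice for encoder-decoder models."""
    if is_mllama:
        # models/mllama.py still depends on PagedAttention, so it stays on
        # the XFormers backend even with dtype=bfloat16.
        return "XFORMERS"
    if dtype == torch.bfloat16:
        # Per this PR, bfloat16 encoder-decoder models now default to
        # FlashAttention instead of XFormers.
        return "FLASH_ATTN"
    # Other dtypes are assumed (for this sketch) to keep the XFormers default.
    return "XFORMERS"
```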

#7366

heheda12345 (Collaborator) commented:

I think it won't be too difficult to support mllama+flashattention. @sroy745 ping me if you need more background information. I'll go through the code later today.

heheda12345 (Collaborator) left a review:

Thanks for the great work. I left some comments, mainly about simplifying the logic for the different AttentionType values.

vllm/attention/backends/flash_attn.py (two outdated review comments, resolved)
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten()
if updated_slot_mapping is not None else None,

Collaborator:

I think we do not need this branch. In the decode phase with attn_type == ENCODER_DECODER, key and value should be None, so we will not enter the true branch of if (attn_type != AttentionType.ENCODER) and (key is not None) and (value is not None):
I think we can add a comment to explain this and remove the branches.

Author:

Yes, this is not needed, as you mentioned. I had it because I was getting a mypy type-check error. Removed the condition and added an ignore annotation instead.
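
For context, here is a minimal, self-contained sketch of the guard being discussed; the helper name and the string stand-ins for AttentionType are hypothetical, and only the condition itself mirrors the one quoted above:

```python
from typing import Optional

import torch

# Hypothetical stand-ins for vLLM's AttentionType values (illustration only).
ENCODER, DECODER, ENCODER_DECODER = "encoder", "decoder", "encoder_decoder"


def should_write_kv(attn_type: str,
                    key: Optional[torch.Tensor],
                    value: Optional[torch.Tensor]) -> bool:
    """Mirror of the quoted guard: True only when there is new KV to cache."""
    # In the decode phase with attn_type == ENCODER_DECODER, key and value are
    # None (the encoder KV was already written during prefill), so the guard is
    # never entered and no separate decode-phase branch is needed.
    return (attn_type != ENCODER) and (key is not None) and (value is not None)


# Cross-attention decode step: key/value are None, so no KV-cache write happens.
assert should_write_kv(ENCODER_DECODER, None, None) is False
```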

raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_num_prefill_encode_decode_tokens(

Collaborator:

If this function is the same as the one in the XFormers backend, can we move it to utils.py and call it from both the FlashAttention and XFormers backends?

Author:

Moved it to utils.py, and it is now used in the XFormers backend as well. However, there is a slight difference in how I set num_encoder_tokens when attention_type = DECODER, as you noted in your other comment.

Comment on lines 891 to 897
if (attn_type == AttentionType.ENCODER or \
attn_type == AttentionType.ENCODER_DECODER):
key = key[:num_encoder_tokens]
value = value[:num_encoder_tokens]
else:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

Collaborator:

In _get_num_prefill_encode_decode_tokens(), if you set num_encoder_tokens = attn_metadata.num_prefill_tokens when attn_type == AttentionType.DECODER, like the XFormers backend does, you can avoid this branch. num_encoder_tokens plays a role similar to q_len.

Also, I am not sure it is correct to remove these lines and pass the full key and value to the attention kernel.

Author:

I am slightly in favor of this approach because, when the AttentionType is DECODER, it seems more intuitive to split the key based on prefill_tokens rather than encoder_tokens; encoder_tokens seems more relevant when the AttentionType is ENCODER or ENCODER_DECODER.

I also modified the XFormers code to use the same split, since both backends now use the common get_num_prefill_encode_decode_tokens. Please let me know your preference and I will update both backends accordingly.

Collaborator:

I think the goal of _get_num_prefill_encode_decode_tokens is to unify the handling of the different attention types and avoid branching in the code that follows as much as possible. What about renaming the three variables to make this clearer, e.g., num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens?

Author:

Done. Renamed the variables to num_prefill_query_tokens, num_prefill_kv_tokens, and num_decode_query_tokens, and removed the if branches in the backends.
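
For readers following the thread, here is a minimal sketch of what such a unified helper could look like; the function name, the metadata stub, and the exact token accounting are illustrative assumptions, not the code merged in this PR:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for vLLM's AttentionType values (illustration only).
ENCODER, DECODER, ENCODER_DECODER = "encoder", "decoder", "encoder_decoder"


@dataclass
class AttnMetadataStub:
    """Simplified stand-in for the backend's attention metadata."""
    num_prefill_tokens: int   # decoder-side prefill query tokens
    num_decode_tokens: int    # decoder-side decode query tokens
    num_encoder_tokens: int   # encoder sequence tokens


def get_query_kv_token_counts(meta: AttnMetadataStub, attn_type: str):
    """Return (num_prefill_query_tokens, num_prefill_kv_tokens,
    num_decode_query_tokens) so both backends can slice q/k/v uniformly."""
    if attn_type == ENCODER:
        # Encoder self-attention: queries and keys both come from the encoder.
        return meta.num_encoder_tokens, meta.num_encoder_tokens, 0
    if attn_type == ENCODER_DECODER:
        # Cross-attention: decoder queries attend over encoder keys/values.
        return meta.num_prefill_tokens, meta.num_encoder_tokens, meta.num_decode_tokens
    # Decoder self-attention: keys/values follow the decoder query split.
    return meta.num_prefill_tokens, meta.num_prefill_tokens, meta.num_decode_tokens


# With these counts a backend can slice without per-attention-type branches, e.g.:
#   key = key[:num_prefill_kv_tokens]
#   value = value[:num_prefill_kv_tokens]
```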

vllm/attention/backends/flash_attn.py (three outdated review comments, resolved)
vllm/worker/enc_dec_model_runner.py (outdated review comment, resolved)
sroy745 (Author) commented Oct 30, 2024

Thanks for the review. Addressed your comments. PTAL.

heheda12345 (Collaborator) commented:

Thanks for your fix. I left some comments.

sroy745 (Author) commented Nov 1, 2024

Thanks for the review. Addressed comments. PTAL.

mergify bot commented Nov 1, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @sroy745 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Nov 1, 2024
heheda12345 (Collaborator) left a review:

LGTM. Thanks for your hard work on this. Looking forward to the follow-up PRs for test_encoder_decoder_attention and mllama support.

Also CC @WoosukKwon. You may need to sync this PR to v1 later.

mergify bot removed the needs-rebase label on Nov 1, 2024
sroy745 (Author) commented Nov 1, 2024

@ywang96 PTAL when you get a chance. The PR has been LGTM'd by @heheda12345, is synced to head, and all tests are passing.

mergify bot commented Nov 1, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @sroy745 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Nov 1, 2024
mergify bot removed the needs-rebase label on Nov 1, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)