
Conversation

@vllmellm (Contributor) commented Apr 1, 2025

Description

This PR integrates AITER ops into vLLM to improve MLA functionality, using AITER's flash_attn_varlen_func and mla_decode_fwd. It also allows any upcoming optimizations in AITER kernels to be used and evaluated directly within the vLLM framework.

Implementation

ROCM_AITER_MLA is introduced as an additional attention backend type for the ROCm platform.
To support this backend, the classes below are implemented in vllm/attention/backends/rocm_aiter_mla.py:

  • AiterMLABackend inherits from MLACommonBackend.
  • AiterMLAMetadata inherits from MLACommonMetadata: note that the advance_step function in this class uses the advance_step_flashinfer function from vLLM custom ops.
  • AiterMLAMetadataBuilder inherits from MLACommonMetadataBuilder.
  • AiterMLAState inherits from MLACommonState.
  • AiterMLAImpl inherits from MLACommonImpl (see the sketch after this list).
    Important notes for this class:
    • The flash_attn_varlen_func (FA function) used in this class is the AITER FA implementation (flash_attn_varlen_func from the AITER package).
    • The _forward_decode function in this class uses the mla_decode_fwd kernel from the AITER package.
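
The sketch below shows, under stated assumptions, where the two AITER entry points plug into the implementation class. The base class here is a local stand-in rather than the real MLACommonImpl, the aiter import paths are assumptions, and the decode arguments are elided; only the function names come from this PR.

```python
# Hedged sketch: stand-in base class, assumed aiter import paths, elided signatures.
try:
    from aiter import flash_attn_varlen_func   # AITER FA implementation (path assumed)
    from aiter.mla import mla_decode_fwd        # AITER MLA decode kernel (path assumed)
except ImportError:                             # keeps the sketch importable without AITER
    flash_attn_varlen_func = mla_decode_fwd = None


class MLACommonImplStandIn:
    """Local stand-in for the common MLA implementation class."""

    def _forward_decode(self, *args, **kwargs):
        raise NotImplementedError


class AiterMLAImplSketch(MLACommonImplStandIn):
    def __init__(self) -> None:
        # Prefill uses AITER's varlen flash attention instead of the default FA.
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _forward_decode(self, *args, **kwargs):
        # Decode dispatches to the AITER kernel; the real override builds the
        # arguments (queries, paged KV cache, index pointers) from the metadata.
        return mla_decode_fwd(*args, **kwargs)
```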

The MLACommon module has been refactored to reduce code duplication in subclasses of advance_step: the call to ops.advance_step_flashattn is moved into a separate function, _ops_advance_step, that can be overridden by subclasses.
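
As a rough illustration of this hook pattern (stand-in classes and print placeholders, not the real vLLM ops or metadata classes):

```python
class MLACommonMetadataSketch:
    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        # Shared pre-checks stay in the base class...
        assert num_seqs >= num_queries
        # ...while the kernel dispatch is isolated in an overridable hook.
        self._ops_advance_step(num_seqs, num_queries)

    def _ops_advance_step(self, num_seqs: int, num_queries: int) -> None:
        print("base: would call ops.advance_step_flashattn")


class AiterMLAMetadataSketch(MLACommonMetadataSketch):
    def _ops_advance_step(self, num_seqs: int, num_queries: int) -> None:
        # The AITER backend swaps in ops.advance_step_flashinfer here.
        print("aiter: would call ops.advance_step_flashinfer")
```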

To enable the backend, set the environment variable VLLM_ATTN_BACKEND to ROCM_AITER_MLA.
If the backend is not specified, rocm.py in vllm/platforms checks whether VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MLA are both enabled in order to use this backend; otherwise the selected backend is TRITON_MLA.
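
A minimal sketch of the selection rule described above, assuming simplified string inputs and illustrative environment-variable defaults; the actual logic lives in vllm/platforms/rocm.py and differs in detail:

```python
import os


def select_mla_backend(selected_backend: str | None) -> str:
    # An explicit request for the AITER backend wins.
    if selected_backend == "ROCM_AITER_MLA":
        return "ROCM_AITER_MLA"
    # Otherwise fall back to the AITER env switches (defaults here are illustrative).
    use_aiter = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    use_aiter_mla = os.environ.get("VLLM_ROCM_USE_AITER_MLA", "0") == "1"
    if selected_backend is None and use_aiter and use_aiter_mla:
        return "ROCM_AITER_MLA"
    return "TRITON_MLA"
```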

Important Notes:

  • AITER MLA currently supports only block_size=1, and max_model_len=32768 has to be set (see the usage sketch after this list).
  • AITER MLA is suitable for DeepSeek models.
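
A hedged usage sketch tying the notes above together; the model name is only illustrative, and the environment variables are the ones introduced by this PR:

```python
import os

# Opt into the AITER MLA backend before vLLM selects an attention backend.
os.environ["VLLM_ATTN_BACKEND"] = "ROCM_AITER_MLA"
# Alternatively, leave VLLM_ATTN_BACKEND unset and enable both AITER switches:
# os.environ["VLLM_ROCM_USE_AITER"] = "1"
# os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative DeepSeek MLA model
    block_size=1,                          # AITER MLA currently supports block_size=1 only
    max_model_len=32768,                   # required setting per the notes above
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```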

Testing

To ensure the correct attention backend is selected, the MLA backends have been added to the environment-variable-driven test cases in tests/kernels/test_attention_selector.py.

Performance

Benchmark Serving Results Comparison

| Metric | Triton MLA (ROCm Flash Attention) | Triton MLA (Triton Flash Attention) | ROCm AITER MLA (AITER Flash Attention) |
|---|---|---|---|
| **Overall Performance** | | | |
| Successful requests | 1000 | 1000 | 1000 |
| Benchmark duration (s) | 121.13 | 264.31 | 104.67 |
| Total input tokens | 1024000 | 1024000 | 1024000 |
| Total generated tokens | 39139 | 39899 | 40681 |
| Request throughput (req/s) | 8.26 | 3.78 | 9.55 |
| Output token throughput (tok/s) | 323.13 | 150.96 | 388.66 |
| Total token throughput (tok/s) | 8777.14 | 4025.23 | 10171.83 |
| **Time to First Token (TTFT)** | | | |
| Mean TTFT (ms) | 55437.36 | 116591.00 | 46060.49 |
| Median TTFT (ms) | 51164.28 | 101109.93 | 42263.81 |
| P99 TTFT (ms) | 114009.92 | 256545.57 | 96858.61 |
| **Time per Output Token (TPOT), excl. 1st** | | | |
| Mean TPOT (ms) | 2053.40 | 5737.02 | 2360.18 |
| Median TPOT (ms) | 713.78 | 1736.03 | 768.247 |
| P99 TPOT (ms) | 8271.61 | 23386.95 | 15863.46 |
| **Inter-token Latency (ITL)** | | | |
| Mean ITL (ms) | 534.54 | 1356.84 | 474.84 |
| Median ITL (ms) | 106.89 | 109.63 | 106.23 |
| P99 ITL (ms) | 7176.11 | 17639.02 | 5287.23 |

Lm Eval Results

| Tasks | Version | Filter | n-shot | Metric | Value (Without AITER) | Stderr (Without AITER) | Value (With AITER) | Stderr (With AITER) |
|---|---|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | 0.95 | ±0.05 | 0.95 | ±0.05 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | 0.95 | ±0.05 | 0.95 | ±0.05 |

Environment Setting

Updates in Dockerfile.rocm_base:
  • Added the AITER package.

vllmellm and others added 7 commits March 28, 2025 08:19

github-actions bot commented Apr 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀


mergify bot commented Apr 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 1, 2025
vllmellm added 3 commits April 2, 2025 04:51
@mergify mergify bot removed the needs-rebase label Apr 2, 2025
vllmellm added 3 commits April 3, 2025 04:41
@vllmellm vllmellm marked this pull request as ready for review April 3, 2025 05:44
"cpu": [],
}

DEVICE_NON_MLA_BACKENDS = {
Collaborator

nit: let's just call this DEVICE_REGULAR_ATTN_BACKENDS instead of MLA

@vllmellm (Contributor, Author) commented Apr 17, 2025

@LucasWilkinson This has been addressed. Thanks.

self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
Collaborator

nit: why relocate these lines? Also can you please explain to me why we now need self.__class__.BLOCK_TABLE_EXTENDER

Contributor (Author)

self.__class__.BLOCK_TABLE_EXTENDER is a static class variable. The common builder previously hardcoded the extender as [] in the line below:
self.block_tables.extend([] * cuda_graph_pad_size)
In AiterMLAMetadataBuilder, graph capture needs [[]] instead of []. Moving the hardcoded extender into a class variable lets the subclass define its own extender value or simply inherit the parent's.

To review this file it is better to open the entire file, as the GitHub diff view does not represent the changes well.

Overall, as explained in the PR description summarizing the changes, some functions were refactored to accommodate the AITER MLA implementation, reduce code duplication, and allow more flexibility in subclasses.

The MLACommon module has been refactored to reduce code duplication in its subclasses. This was done by separating the attention output computation into two dedicated functions, _get_fwd_prefill_attn_output and _get_prefill_ctx_attn_output, which are used in _compute_prefill_context and _forward_prefill respectively.
Another refactoring separates the pre-assertion checks out of advance_step, so that advance_step can be overridden in subclasses without code duplication.

Contributor (Author)

@LucasWilkinson After resolving the merge conflict for this file, the only changes in common.py are as below:

  • ops.advance_step_flashattn is invoked in a separate function, _ops_advance_step, which is used by advance_step and can be overridden by subclasses.

  • use of "static" class variable as BLOCK_TABLE_EXTENDER: list[list[int]] = [] that is used to update self.block_tables in graph mode which eliminates the hardcoded "[]" self.block_tables.extend([] * cuda_graph_pad_size) to allow flexibility for the subclasses to override this update based on the class variable.

@tjtanaa tjtanaa force-pushed the aiter-mla-integration branch from 5e6ed9a to 6e48433 Compare April 17, 2025 13:17
@LucasWilkinson (Collaborator) left a comment

LGTM, thanks for the contribution

@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 21, 2025
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
"does not support block size {block_size}.")
Collaborator

Suggested change
"does not support block size {block_size}.")
f"does not support block size {block_size}.")

Contributor (Author)

Thanks for pointing this out. I have applied the suggestion.

else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
"does not support block size {block_size}."
Collaborator

Suggested change
"does not support block size {block_size}."
f"does not support block size {block_size}."

@gshtras gshtras mentioned this pull request Apr 21, 2025

mergify bot commented Apr 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 22, 2025
@mergify mergify bot removed the needs-rebase label Apr 22, 2025
… handle wrong backend selection when MLA is requested.

@vllm-bot vllm-bot merged commit 30bc3e0 into vllm-project:main Apr 22, 2025
43 of 46 checks passed
frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025