[FEAT][ROCm]: Support AITER MLA #15893
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
"cpu": [], | ||
} | ||
|
||
DEVICE_NON_MLA_BACKENDS = { |
nit: let's just call this `DEVICE_REGULAR_ATTN_BACKENDS` instead of naming it around MLA
@LucasWilkinson This has been addressed. Thanks.
```diff
- self.block_tables.extend([] * cuda_graph_pad_size)
  num_decode_tokens = batch_size - self.num_prefill_tokens
  self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+ self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
+                          cuda_graph_pad_size)
```
nit: why relocate these lines? Also, can you please explain why we now need `self.__class__.BLOCK_TABLE_EXTENDER`?
`self.__class__.BLOCK_TABLE_EXTENDER` is a static class variable. The common implementation had the extender hardcoded as `[]` in the line below:

```
self.block_tables.extend([] * cuda_graph_pad_size)
```

In `AiterMLAMetadataBuilder`, capturing the graph requires `[[]]` instead of `[]`. Moving the hardcoded extender into a class variable allows a subclass to supply its own extender value or simply inherit the parent's.
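To make the pattern concrete, here is a minimal, self-contained sketch. The names `BLOCK_TABLE_EXTENDER`, `MLACommonMetadataBuilder`, and `AiterMLAMetadataBuilder` are from this PR; the method `pad_for_graph_capture` and the rest of the scaffolding are simplified stand-ins for the real builder code:

```python
from typing import List


class MLACommonMetadataBuilder:
    # The previously hardcoded "[]" extender, lifted into a class variable.
    BLOCK_TABLE_EXTENDER: List[List[int]] = []

    def __init__(self) -> None:
        self.block_tables: List[List[int]] = []

    def pad_for_graph_capture(self, cuda_graph_pad_size: int) -> None:
        # For the base class this is a no-op extend, since [] * n == [].
        self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
                                 cuda_graph_pad_size)


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder):
    # AITER MLA graph capture needs "[[]]": one empty block table
    # per padded graph slot.
    BLOCK_TABLE_EXTENDER: List[List[int]] = [[]]


builder = AiterMLAMetadataBuilder()
builder.pad_for_graph_capture(cuda_graph_pad_size=2)
assert builder.block_tables == [[], []]  # [[]] * 2 == [[], []]
```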
To review this file, it is better to open the entire file, as the GitHub diff view is not representative enough of what has changed. Overall, as explained in the PR description, some refactoring was made in certain functions to accommodate the AITER MLA implementation, reduce code duplication in the subclasses, and allow more flexibility in subclasses.
Implementation

`ROCM_AITER_MLA` is introduced as an additional attention backend type for the ROCm platform. To support this backend, the modules below are implemented in `vllm/attention/backends/rocm_aiter_mla.py`:

- `AiterMLABackend` inherits from `MLACommonBackend`.
- `AiterMLAMetadata` inherits from `MLACommonMetadata`. Note that this class's `advance_step` function utilizes the `advance_step_flashinfer` function from vLLM custom ops.
- `AiterMLAMetadataBuilder` inherits from `MLACommonMetadataBuilder`.
- `AiterMLAState` inherits from `MLACommonState`.
- `AiterMLAImpl` inherits from `MLACommonImpl`. Important notes for this class:
  - `flash_attn_varlen_func` (the FA function) used in this class is the AITER FA implementation (`flash_attn_varlen_func` from the AITER package).
  - The `_forward_decode` function in this class uses the `mla_decode_fwd` kernel from the AITER package.
The MLACommon module has been refactored to reduce code duplication in its subclasses. This was achieved by separating the attention output computation into two dedicated functions, `_get_prefill_ctx_attn_output` and `_get_fwd_prefill_attn_output`, which are used in `_compute_prefill_context` and `_forward_prefill` respectively.
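A rough sketch of that prefill split, under heavily simplified assumptions: the real methods take projected query/KV tensors plus attention metadata, and `scaled_dot_product_attention` stands in for the actual prefill kernels:

```python
import torch


class MLACommonImpl:
    # Overridable hook: only the attention computation lives here, so a
    # subclass can swap in a different kernel (e.g. AITER's
    # flash_attn_varlen_func) without duplicating the prefill logic.
    def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor,
                                     v: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    # Shared control flow stays in the base class and calls the hook.
    def _forward_prefill(self, q: torch.Tensor, k: torch.Tensor,
                         v: torch.Tensor) -> torch.Tensor:
        attn_out = self._get_fwd_prefill_attn_output(q, k, v)
        # ...shared post-processing (reshape/projection) would follow here.
        return attn_out


# Example usage with toy shapes: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 8, 16, 64)
out = MLACommonImpl()._forward_prefill(q, k, v)
assert out.shape == (1, 8, 16, 64)
```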
Another refactoring is in the `advance_step` function: the pre-assertion checks are separated out from the actual advance-step call, allowing `advance_step` to be overridden without code duplication in its subclasses.
@LucasWilkinson after resolving the merge conflict for this file, the only changes in `common.py` are as below:

- invoking `ops.advance_step_flashattn` in a separate function `_ops_advance_step` that can be overridden by a subclass; it is used in the `advance_step` function (see the sketch after this list);
- use of a "static" class variable `BLOCK_TABLE_EXTENDER: list[list[int]] = []` to update `self.block_tables` in graph mode, which eliminates the hardcoded `[]` in `self.block_tables.extend([] * cuda_graph_pad_size)` and gives subclasses the flexibility to override this update via the class variable.
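A condensed sketch of the first change. The class, method, and op names are from this PR; the `**kwargs` pass-through is a simplification of the real argument handling, and the import path assumes vLLM's usual custom-ops module:

```python
from vllm import _custom_ops as ops


class MLACommonMetadata:
    def advance_step(self, **kwargs) -> None:
        # The shared pre-assertion checks from common.py would run here,
        # then the actual custom op is dispatched through the hook below.
        self._ops_advance_step(**kwargs)

    def _ops_advance_step(self, **kwargs) -> None:
        # Base class calls the FlashAttention advance-step custom op.
        ops.advance_step_flashattn(**kwargs)


class AiterMLAMetadata(MLACommonMetadata):
    def _ops_advance_step(self, **kwargs) -> None:
        # AITER MLA overrides only the dispatch, reusing the FlashInfer
        # advance-step custom op as described in the PR.
        ops.advance_step_flashinfer(**kwargs)
```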
LGTM, thanks for the contribution
vllm/platforms/rocm.py
```
else:
    raise ValueError(
        f" The selected backend, {selected_backend.name},"
        "does not support block size {block_size}.")
```
"does not support block size {block_size}.") | |
f"does not support block size {block_size}.") |
Thanks for pointing this out. Have added the suggestion.
vllm/platforms/rocm.py
```
else:
    raise ValueError(
        f" The selected backend, {selected_backend.name},"
        "does not support block size {block_size}."
```
"does not support block size {block_size}." | |
f"does not support block size {block_size}." |
This pull request has merge conflicts that must be resolved before it can be merged.
Description
This PR integrates AITER ops into vLLM to improve the MLA functionality, using AITER `flash_attn_varlen_func` and AITER `mla_decode_fwd`, and will allow any upcoming optimizations in the AITER kernels to be directly used and evaluated within the vLLM framework.
Implementation

`ROCM_AITER_MLA` is introduced as an additional attention backend type for the ROCm platform. To support this backend, the modules below are implemented in `vllm/attention/backends/rocm_aiter_mla.py`:

- `AiterMLABackend` inherits from `MLACommonBackend`.
- `AiterMLAMetadata` inherits from `MLACommonMetadata`. Note that this class's `advance_step` function utilizes the `advance_step_flashinfer` function from vLLM custom ops.
- `AiterMLAMetadataBuilder` inherits from `MLACommonMetadataBuilder`.
- `AiterMLAState` inherits from `MLACommonState`.
- `AiterMLAImpl` inherits from `MLACommonImpl`. Important notes for this class:
  - `flash_attn_varlen_func` (the FA function) used in this class is the AITER FA implementation (`flash_attn_varlen_func` from the AITER package).
  - The `_forward_decode` function in this class uses the `mla_decode_fwd` kernel from the AITER package.

The MLACommon module has been refactored to reduce code duplication in its subclasses for the `advance_step` function, by invoking `ops.advance_step_flashattn` in a separate function `_ops_advance_step` that can be overridden by a subclass.
To enable the backend, the environment variable `VLLM_ATTN_BACKEND` can be set to `ROCM_AITER_MLA`.

In case the backend is not specified, `rocm.py` in `vllm/platforms` verifies whether `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MLA` are both enabled in order to utilize this backend. Otherwise the selected backend is `TRITON_MLA`.
Important Notes:
- `block_size=1` and `max_model_len=32768` have to be set.
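For illustration, enabling the backend from Python might look like the sketch below. The environment variables and the `block_size`/`max_model_len` constraints are from this PR; the model choice and the exact `LLM` keyword usage are assumptions:

```python
import os

# Explicit selection: force the AITER MLA backend.
os.environ["VLLM_ATTN_BACKEND"] = "ROCM_AITER_MLA"

# Implicit selection: with no explicit backend, rocm.py picks AITER MLA
# only when both flags below are enabled; otherwise it falls back to
# TRITON_MLA.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA model choice
    block_size=1,          # per the notes above
    max_model_len=32768,   # per the notes above
)
```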
Testing
In order to ensure the correct attention backend is selected, the MLA backend env settings have been added to the test cases in `tests/kernels/test_attention_selector.py`.
Performance
Benchmark Serving Results Comparison
LM Eval Results
Environment Setting
Updates in `Dockerfile.rocm_base`:
- Added the AITER package.
Additional notes on installing AITER:
- When setting up AITER, it is crucial to clone with `git clone --recursive`, because the package depends on a third-party package (Composable Kernel).
- For building and installing the AITER Python package, you must use the `PREBUILD_KERNELS=1` flag along with the command `python3 setup.py develop`. This ensures that all kernels in the AITER package are built successfully.
The following branches were used as references for this integration:
- https://github.com/ROCm/vllm/tree/dsv3_dev
- https://github.com/ROCm/vllm/tree/aiter_integration_final
- https://github.com/ROCm/vllm/tree/deepseek_v3_dev