Commit a94b467

vllmellm and qli88 authored and committed
[FEAT][ROCm]: Support AITER MLA (vllm-project#15893)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
1 parent f3bda64 commit a94b467

File tree

9 files changed: +667 -29 lines changed

tests/kernels/test_attention_selector.py

Lines changed: 128 additions & 21 deletions
@@ -19,45 +19,152 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize(
-    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
+# Define MLA and non-MLA backends separately
+DEVICE_MLA_BACKENDS = {
+    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
+    "cpu": [],
+}
+
+DEVICE_REGULAR_ATTN_BACKENDS = {
+    "cuda": ["XFORMERS", "FLASHINFER"],
+    "hip": ["ROCM_FLASH"],
+    "cpu": ["TORCH_SDPA"],
+}
+
+DEVICE_MLA_BLOCK_SIZES = {
+    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
+    "hip": [16, 1],  # HIP requires special handling for block_size=1
+    "cpu": [16]  # CPU uses fixed block size from test cases
+}
+
+
+def generate_params():
+    params = []
+    for use_mla in [True, False]:
+        for device in ["cuda", "hip", "cpu"]:
+            backends = DEVICE_MLA_BACKENDS[
+                device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
+            for name in backends:
+                block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
+                    16
+                ]
+                for block_size in block_sizes:
+                    params.append(
+                        pytest.param(
+                            device,
+                            name,
+                            use_mla,
+                            block_size,
+                            id=
+                            f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
+                        ))
+    return params
+
+
+@pytest.mark.parametrize("device, name, use_mla, block_size",
+                         generate_params())
 @pytest.mark.parametrize("use_v1", [True, False])
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_env(
+    device: str,
     name: str,
+    use_mla: bool,
+    block_size: int,
     use_v1: bool,
-    device: str,
     monkeypatch: pytest.MonkeyPatch,
 ):
-    """Test that the attention selector can be set via environment variable.
-    Note that we do not test FlashAttn because it is the default backend.
-    """
-
+    """Test attention backend selection with valid device-backend pairs."""
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, name)
+        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
 
         if device == "cpu":
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           16, False)
+                                           block_size, False)
             assert backend.get_name() == "TORCH_SDPA"
+
         elif device == "hip":
             with patch("vllm.attention.selector.current_platform",
                        RocmPlatform()):
-                backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           16, False)
-            EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
-            assert backend.get_name() == EXPECTED
-        else:
-            if name in ["XFORMERS", "FLASHINFER"]:
-                with patch("vllm.attention.selector.current_platform",
-                           CudaPlatform()):
-                    backend = get_attn_backend(16, torch.float16,
-                                               torch.float16, 16, False)
-                EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
-                assert backend.get_name() == EXPECTED
+                if use_mla:
+                    # Validate HIP MLA backend-block_size combinations
+                    valid_combination = (
+                        (name == "TRITON_MLA" and block_size != 1)
+                        or (name == "ROCM_AITER_MLA" and block_size == 1))
+
+                    if valid_combination:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        assert backend.get_name() == name
+                    else:
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16,
+                                             torch.float16,
+                                             torch.float16,
+                                             block_size,
+                                             False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
+                    assert backend.get_name() == expected
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                if use_mla:
+                    if name == "FLASHMLA" and block_size == 64:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+
+                        # only on cuda platforms with specific capability.
+                        is_supported, _ = is_flashmla_supported()
+
+                        if not is_supported:
+                            # if platform is not supported then skip this case.
+                            pytest.skip()
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = ("TRITON_MLA_VLLM_V1"
+                                    if use_v1 else "TRITON_MLA")
+                        assert backend.get_name() == expected
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
 
 
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
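
For quick reference, the HIP branch of the new test only treats an MLA backend/block-size pair as valid when the block size matches the backend: TRITON_MLA with any block size other than 1, and ROCM_AITER_MLA only with block size 1; every other MLA pairing is expected to raise a ValueError. Below is a minimal sketch of that rule, using nothing beyond the names in the test; the helper is_valid_hip_mla_combo is hypothetical, not part of vLLM.

def is_valid_hip_mla_combo(name: str, block_size: int) -> bool:
    # Mirrors the valid_combination check in the HIP branch of test_env.
    return ((name == "TRITON_MLA" and block_size != 1)
            or (name == "ROCM_AITER_MLA" and block_size == 1))


assert is_valid_hip_mla_combo("ROCM_AITER_MLA", 1)
assert is_valid_hip_mla_combo("TRITON_MLA", 16)
assert not is_valid_hip_mla_combo("TRITON_MLA", 1)  # the test expects a ValueError here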

tests/kernels/test_rocm_attention_selector.py

Lines changed: 28 additions & 1 deletion
@@ -28,7 +28,34 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         assert (backend.get_name() == "ROCM_FLASH"
                 or backend.get_name() == "TRITON_ATTN_VLLM_V1")
 
-        # mla test for deepseek related
+        # MLA test for deepseek related
+
+        # change the attention backend to triton MLA
+        m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
+                                   False, True)
+        assert backend.get_name() == "TRITON_MLA"
+
+        # If attention backend is None
+        # If use_mla is true
+        # The selected backend is triton MLA
+        m.setenv(STR_BACKEND_ENV_VAR, None)
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                    False, True)
         assert backend.get_name() == "TRITON_MLA"
+
+        # change the attention backend to AITER MLA
+        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
+                                   False, True)
+        assert backend.get_name() == "ROCM_AITER_MLA"
+
+        # If attention backend is None
+        # If use_mla is true
+        # If VLLM_ROCM_USE_AITER is enabled
+        # The selected backend is ROCM_AITER_MLA
+        m.setenv(STR_BACKEND_ENV_VAR, None)
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
+                                   False, True)
+        assert backend.get_name() == "ROCM_AITER_MLA"
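
Read together, the new assertions suggest two ways a ROCm user ends up on the AITER MLA backend: name it explicitly through the attention-backend environment variable, or leave the backend unset and enable AITER with VLLM_ROCM_USE_AITER=1; in both cases use_mla is on and the block size is 1. The snippet below is a rough sketch of that flow outside pytest, reusing only the calls shown in this diff and assuming STR_BACKEND_ENV_VAR and _cached_get_attn_backend are importable from vllm.utils and vllm.attention.selector as in these test files; it would only pass on a ROCm build.

import os

import torch

from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.utils import STR_BACKEND_ENV_VAR

# Option 1: request the backend by name; AITER MLA requires block_size == 1.
os.environ[STR_BACKEND_ENV_VAR] = "ROCM_AITER_MLA"
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True)
assert backend.get_name() == "ROCM_AITER_MLA"

# Option 2: leave the backend unset and enable AITER kernels globally.
os.environ.pop(STR_BACKEND_ENV_VAR, None)
os.environ["VLLM_ROCM_USE_AITER"] = "1"
_cached_get_attn_backend.cache_clear()  # the selector caches its result
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True)
assert backend.get_name() == "ROCM_AITER_MLA"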

vllm/attention/backends/mla/common.py

Lines changed: 18 additions & 3 deletions
@@ -711,12 +711,24 @@ def advance_step(self,
             self.seq_lens[i] += 1
         self.max_decode_seq_len = max(self.seq_lens)
 
+        self._ops_advance_step(num_seqs=num_seqs,
+                               num_queries=num_queries,
+                               block_size=block_size,
+                               input_tokens=model_input.input_tokens,
+                               sampled_token_ids=sampled_token_ids,
+                               input_positions=model_input.input_positions)
+
+    def _ops_advance_step(self, num_seqs: int, num_queries: int,
+                          block_size: int, input_tokens: torch.Tensor,
+                          sampled_token_ids: torch.Tensor,
+                          input_positions: torch.Tensor) -> None:
+        # here we use advance_step_flashinfo to update the paged_kv_* tensors
         ops.advance_step_flashattn(num_seqs=num_seqs,
                                    num_queries=num_queries,
                                    block_size=block_size,
-                                   input_tokens=model_input.input_tokens,
+                                   input_tokens=input_tokens,
                                    sampled_token_ids=sampled_token_ids,
-                                   input_positions=model_input.input_positions,
+                                   input_positions=input_positions,
                                    seq_lens=self.seq_lens_tensor,
                                    slot_mapping=self.slot_mapping,
                                    block_tables=self.block_tables)
@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
+    BLOCK_TABLE_EXTENDER: list[list[int]] = []
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.input_builder = input_builder
@@ -877,8 +890,10 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         num_seqs = len(seq_lens)
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
-            self.block_tables.extend([] * cuda_graph_pad_size)
+            self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
+                                     cuda_graph_pad_size)
             num_decode_tokens = batch_size - self.num_prefill_tokens
+
             block_tables = self._get_graph_runner_block_tables(
                 num_seqs, self.block_tables)
         else:
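
The common.py changes are a small refactor that turns two previously hard-coded behaviours into extension points: _ops_advance_step isolates the ops.advance_step_flashattn call so a subclass can swap in a different decode-step kernel, and BLOCK_TABLE_EXTENDER controls what gets appended to block tables when padding for captured CUDA graphs. A rough sketch of how a backend-specific subclass might use these hooks follows; the class names and override bodies are illustrative assumptions (the real AITER MLA classes live in other files of this PR), and it assumes the advance_step method shown above belongs to MLACommonMetadata.

import torch

from vllm.attention.backends.mla.common import (MLACommonMetadata,
                                                MLACommonMetadataBuilder)


class MyMLAMetadata(MLACommonMetadata):

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:
        # Hypothetical override: route the decode-step update to a
        # backend-specific kernel instead of ops.advance_step_flashattn.
        raise NotImplementedError


class MyMLAMetadataBuilder(MLACommonMetadataBuilder[MyMLAMetadata]):
    # Hypothetical override: pad captured-graph block tables with one empty
    # row per padded slot rather than appending nothing.
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]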
