
Commit 6cae1e5

[ROCm][MLA] Support block-size > 1 for AITER MLA backend (#27224)

Signed-off-by: ganyi <ygan@amd.com>
Co-authored-by: wuhuikx <hattie.wu@amd.com>

1 parent: 80c9275 · commit: 6cae1e5

File tree

3 files changed: +34 -24 lines changed


tests/kernels/attention/test_attention_selector.py

Lines changed: 0 additions & 7 deletions

@@ -104,13 +104,6 @@ def test_env(
                     16, torch.float16, None, block_size, use_mla=use_mla
                 )
             assert f"The selected backend, {name}" in str(exc_info.value)
-        elif name == "ROCM_AITER_MLA" and block_size != 1:
-            # ROCM_AITER_MLA only supports block_size == 1
-            with pytest.raises(ValueError) as exc_info:
-                get_attn_backend(
-                    16, torch.float16, None, block_size, use_mla=use_mla
-                )
-            assert f"The selected backend, {name}" in str(exc_info.value)
         else:
             # Valid backend-block_size combination
             backend = get_attn_backend(
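
With the ROCM_AITER_MLA-specific branch removed, a block size greater than 1 now falls through to the generic "valid backend-block_size combination" path at the end of test_env. A minimal sketch of that expectation, not part of the commit (the import path and the standalone test shape are assumptions, and a ROCm build with AITER is assumed):

import pytest
import torch

from vllm.attention.selector import get_attn_backend  # assumed import path


@pytest.mark.parametrize("block_size", [1, 16])
def test_rocm_aiter_mla_accepts_larger_block_sizes(block_size):
    # Same call shape as in test_env; after this change no ValueError is
    # expected for ROCM_AITER_MLA when block_size > 1.
    backend = get_attn_backend(16, torch.float16, None, block_size, use_mla=True)
    assert backend is not None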

vllm/platforms/rocm.py

Lines changed: 3 additions & 10 deletions

@@ -252,16 +252,9 @@ def get_attn_backend_cls(
                    f"does not support block size {block_size}."
                )
            if selected_backend == _Backend.ROCM_AITER_MLA:
-                if block_size == 1:
-                    logger.info("Using AITER MLA backend.")
-                    return (
-                        "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
-                    )
-                raise ValueError(
-                    f" The selected backend, {selected_backend.name},"
-                    f"does not support block size {block_size}."
-                    "(currently only supports block size 1)"
-                )
+                logger.info("Using AITER MLA backend.")
+                return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+
            raise ValueError(
                f" The selected backend, {selected_backend.name},"
                f"is not MLA type while requested for MLA backend."

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 31 additions & 7 deletions

@@ -78,9 +78,6 @@ def __init__(
         super().__init__(
             kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
         )
-        assert self.kv_cache_spec.block_size == 1, (
-            "AITER MLAonly supports block size 1."
-        )

         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
@@ -94,6 +91,11 @@ def __init__(
         # so we can only use the persistent buffer if a cudagraph is actually
         # being used.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.block_table_remapping = torch.zeros(
+                [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
+                dtype=torch.int32,
+                device=device,
+            )
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
             )
@@ -119,13 +121,29 @@ def _build_decode(
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
-        block_table_bounds = (seq_lens_device + page_size - 1) // page_size
         device = self.device
         num_reqs = seq_lens_device.size(0)
+        bs, _ = block_table_tensor.shape
+        block_table_tensor = (
+            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
+        )
+        block_table_tensor = (
+            block_table_tensor
+            + torch.arange(
+                0,
+                page_size,
+                device=block_table_tensor.device,
+                dtype=block_table_tensor.dtype,
+            )[None, None, :]
+        )
+        block_table_tensor = block_table_tensor.view(bs, -1)

+        # after remapping, we assume the block size already equals to 1
+
+        max_blk_size_per_req = block_table_tensor.shape[-1]
         mask = torch.arange(
             block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
-        ).unsqueeze(0) < block_table_bounds.unsqueeze(1)
+        ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]

         paged_kv_last_page_len = seq_lens_device % page_size
@@ -135,13 +153,19 @@ def _build_decode(

         paged_kv_indptr = torch.cat(
             [
-                torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
-                block_table_bounds.cumsum(dim=0, dtype=torch.int32),
+                torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
+                seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
+            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
+                block_table_tensor, non_blocking=True
+            )
+            block_table_tensor = self.block_table_remapping[
+                :num_reqs, :max_blk_size_per_req
+            ]

             self.paged_kv_indices[:num_actual_pages].copy_(
                 paged_kv_indices, non_blocking=True
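
The core of the change is the remapping at the top of _build_decode: every block-table entry, which addresses one physical block of page_size tokens, is expanded into page_size consecutive block-size-1 indices, so the downstream AITER MLA decode path can keep treating the KV cache as if the block size were 1 (and paged_kv_indptr becomes a cumulative sum of seq_lens rather than of per-request block counts). A standalone sketch of that arithmetic, with made-up tensor values for illustration only:

import torch

page_size = 4
# Two requests, each with up to three physical blocks of page_size tokens.
block_table = torch.tensor([[7, 2, 0],
                            [5, 1, 3]], dtype=torch.int32)
bs, _ = block_table.shape

# Scale each block id by page_size, then add the in-block offsets
# 0..page_size-1, turning every physical block into page_size consecutive
# virtual block-size-1 slots.
remapped = block_table.unsqueeze(-1).expand(-1, -1, page_size) * page_size
remapped = remapped + torch.arange(page_size, dtype=block_table.dtype)[None, None, :]
remapped = remapped.view(bs, -1)

print(remapped[0])
# tensor([28, 29, 30, 31,  8,  9, 10, 11,  0,  1,  2,  3], dtype=torch.int32)
# Valid entries per request can then be masked directly with seq_lens (token
# counts) instead of ceil(seq_lens / page_size) block counts.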
