7 changes: 0 additions & 7 deletions tests/kernels/attention/test_attention_selector.py
@@ -104,13 +104,6 @@ def test_env(
16, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(
13 changes: 3 additions & 10 deletions vllm/platforms/rocm.py
@@ -252,16 +252,9 @@ def get_attn_backend_cls(
f"does not support block size {block_size}."
)
if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1:
logger.info("Using AITER MLA backend.")
return (
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
)
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)"
)
logger.info("Using AITER MLA backend.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501

raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend."
38 changes: 31 additions & 7 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -78,9 +78,6 @@ def __init__(
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
)
assert self.kv_cache_spec.block_size == 1, (
"AITER MLAonly supports block size 1."
)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(
@@ -94,6 +91,11 @@ def __init__(
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.block_table_remapping = torch.zeros(
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
@@ -119,13 +121,29 @@ def _build_decode(
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
num_reqs = seq_lens_device.size(0)
bs, _ = block_table_tensor.shape
block_table_tensor = (
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
)
block_table_tensor = (
block_table_tensor
+ torch.arange(
0,
page_size,
device=block_table_tensor.device,
dtype=block_table_tensor.dtype,
)[None, None, :]
)
block_table_tensor = block_table_tensor.view(bs, -1)

# after remapping, we assume the block size already equals 1

max_blk_size_per_req = block_table_tensor.shape[-1]
mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < block_table_bounds.unsqueeze(1)
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]

paged_kv_last_page_len = seq_lens_device % page_size

Review comment (P1): Recompute last page lengths after token-level remapping

After expanding each block table entry into per-token indices, the code still derives paged_kv_last_page_len from the original page_size (seq_lens % page_size, falling back to page_size). Once the remapping is done, each entry represents a single token, so the last-page length for any non-empty request should always be 1. Keeping the old computation makes the decode kernel believe that the final page contains page_size tokens (e.g. 128) and read that many elements starting from the last token's index, potentially stepping past the valid token range when block_size > 1. This defeats the goal of supporting larger block sizes and can lead to out-of-bounds accesses or garbage attention results for any request longer than one token.

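A minimal sketch of the computation this comment calls for, assuming the per-token remapping above (the helper name is hypothetical and not part of this PR):

```python
import torch

def last_page_len_after_remap(seq_lens: torch.Tensor) -> torch.Tensor:
    # Once every remapped page holds exactly one token, the final page of any
    # non-empty request contains exactly one token; empty requests contribute 0.
    # This would replace the pre-remap formula seq_lens % page_size (with its
    # fallback to page_size when the remainder is zero).
    return (seq_lens > 0).to(seq_lens.dtype)
```

For example, with seq_lens = [5, 128, 0] and page_size = 64, the old formula yields [5, 64, 64], while the per-token layout calls for [1, 1, 0].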

@@ -135,13 +153,19 @@ def _build_decode(

paged_kv_indptr = torch.cat(
[
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32),
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
]
)

if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
block_table_tensor, non_blocking=True
)
block_table_tensor = self.block_table_remapping[
:num_reqs, :max_blk_size_per_req
]

self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True
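For reference, a self-contained sketch of the per-token remapping that _build_decode now performs; the function below is illustrative only and assumes the same tensor shapes as in the diff:

```python
import torch

def remap_block_table_to_tokens(block_table: torch.Tensor, page_size: int) -> torch.Tensor:
    # block_table: [num_reqs, max_pages] of page ids. Each page id b is
    # expanded into the per-token slot ids b * page_size + [0, page_size),
    # so downstream code can treat the KV cache as if block_size == 1.
    bs, _ = block_table.shape
    expanded = block_table.unsqueeze(-1).expand(-1, -1, page_size) * page_size
    offsets = torch.arange(
        page_size, device=block_table.device, dtype=block_table.dtype
    )
    return (expanded + offsets[None, None, :]).reshape(bs, -1)

# Example: page_size = 4, block_table = [[3, 7]]
# -> [[12, 13, 14, 15, 28, 29, 30, 31]]
```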