Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,8 @@ def get_attn_backend_cls(
f"does not support block size {block_size}."
)
if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.")
return (
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
)
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)"
)
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend."
Expand Down
38 changes: 31 additions & 7 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ def __init__(
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
)
assert self.kv_cache_spec.block_size == 1, (
"AITER MLAonly supports block size 1."
)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(
Expand All @@ -94,6 +91,11 @@ def __init__(
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.block_table_remapping = torch.zeros(
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
Expand All @@ -119,13 +121,29 @@ def _build_decode(
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
num_reqs = seq_lens_device.size(0)
bs, _ = block_table_tensor.shape
block_table_tensor = (
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
)
block_table_tensor = (
block_table_tensor
+ torch.arange(
0,
page_size,
device=block_table_tensor.device,
dtype=block_table_tensor.dtype,
)[None, None, :]
)
block_table_tensor = block_table_tensor.view(bs, -1)

# after remapping, we assume the block size already equals to 1

max_blk_size_per_req = block_table_tensor.shape[-1]
mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < block_table_bounds.unsqueeze(1)
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]

paged_kv_last_page_len = seq_lens_device % page_size
Expand All @@ -135,13 +153,19 @@ def _build_decode(

paged_kv_indptr = torch.cat(
[
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32),
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
]
)

if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
block_table_tensor, non_blocking=True
)
block_table_tensor = self.block_table_remapping[
:num_reqs, :max_blk_size_per_req
]

self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True
Expand Down