diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index 222b9c158e5e..45efcbde698b 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="5a77249"
+ARG AITER_BRANCH="c1debd8"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 7ce39110ac01..31980e94a037 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # The number of entries in the last page of each request in
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_len: Optional[torch.Tensor] = None
+    # The query indptr, shape: [num_decodes + 1]
+    qo_indptr: Optional[torch.Tensor] = None
 
 
 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -75,27 +77,33 @@ def _get_paged_kv_tensors(
             seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
+        device = self.runner.device
 
         mask = (torch.arange(block_table.size(1),
                              dtype=block_table.dtype,
-                             device=block_table.device).unsqueeze(0)
+                             device=device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
         paged_kv_indices = block_table[mask]
 
         paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
+            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
 
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size,
                                              paged_kv_last_page_len)
+        qo_indptr = torch.arange(0,
+                                 self._num_decodes + 1,
+                                 step=1,
+                                 dtype=torch.int32,
+                                 device=device)
+
         return (
             paged_kv_indices,
             paged_kv_indptr,
             paged_kv_last_page_len,
+            qo_indptr,
         )
 
     def _build_decode(self, block_table_tensor: torch.Tensor,
@@ -105,6 +113,7 @@
             paged_kv_indices,
             paged_kv_indptr,
             paged_last_page_len,
+            qo_indptr,
         ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
 
         attn_metadata = AiterMLADecodeMetadata(
@@ -112,7 +121,8 @@
             seq_lens=seq_lens,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_last_page_len)
+            paged_kv_last_page_len=paged_last_page_len,
+            qo_indptr=qo_indptr)
 
         return attn_metadata
 
@@ -137,7 +147,10 @@ def __init__(
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
                          **mla_args)
-
+        assert (num_heads == 16 or num_heads == 128), (
+            "Aiter MLA only supports 16 or 128 attention heads, "
+            f"but got {num_heads}.\n"
+            "Try adjusting the tensor_parallel_size value.")
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
         ]
@@ -189,7 +202,18 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
+        if self.num_heads == 16:
+            # The AITER MLA decode kernel only supports
+            # max_seqlen_q=1 when using 16 heads.
+            max_seqlen_qo = 1
+        else:
+            # The AITER MLA decode kernel handles arbitrary
+            # max_seqlen_q values when using 128 heads.
+            assert attn_metadata.prefill is not None
+            max_seqlen_qo = attn_metadata.prefill.max_query_len
+
         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.decode.qo_indptr, max_seqlen_qo,
                              attn_metadata.decode.paged_kv_indptr,
                              attn_metadata.decode.paged_kv_indices,
                              attn_metadata.decode.paged_kv_last_page_len)
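
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the decode-side metadata the diff constructs (paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, and the new qo_indptr). The function name build_decode_metadata and the example inputs are hypothetical; only torch is assumed.

# Illustrative sketch (not part of the patch): rebuilds the same decode
# metadata tensors as _get_paged_kv_tensors, but from standalone inputs.
import torch


def build_decode_metadata(block_table: torch.Tensor, seq_lens: torch.Tensor,
                          page_size: int, num_decodes: int):
    # Number of pages actually used by each request (ceil division).
    block_table_bounds = (seq_lens + page_size - 1) // page_size

    # Keep only the valid page ids of each request, flattened in order.
    mask = (torch.arange(block_table.size(1),
                         dtype=block_table.dtype,
                         device=block_table.device).unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    paged_kv_indices = block_table[mask]

    # Page indptr: [0, n0, n0+n1, ...], shape [batch_size + 1].
    paged_kv_indptr = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=block_table.device),
        block_table_bounds.cumsum(dim=0, dtype=torch.int32)
    ])

    # Entries in the last page; a completely full page yields page_size, not 0.
    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)

    # Query indptr for pure decode: one token per request -> [0, 1, ..., n].
    qo_indptr = torch.arange(0, num_decodes + 1, dtype=torch.int32,
                             device=block_table.device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, qo_indptr


if __name__ == "__main__":
    # Two decode requests with 5 and 3 cached tokens, page size 4.
    block_table = torch.tensor([[7, 8, 0], [2, 0, 0]], dtype=torch.int32)
    seq_lens = torch.tensor([5, 3], dtype=torch.int32)
    print(build_decode_metadata(block_table, seq_lens, page_size=4,
                                num_decodes=2))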