diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 6308f63cc4e7..ce758b3f3013 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -33,6 +33,14 @@ def aiter_mla_decode_fwd(
     kv_indices: torch.Tensor | None = None,
     kv_last_page_lens: torch.Tensor | None = None,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ):
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(
         q,
@@ -45,6 +53,14 @@ def aiter_mla_decode_fwd(
         kv_last_page_lens,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )
 
 
@@ -59,6 +75,14 @@ def mla_decode_fwd_impl(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd
 
@@ -73,6 +97,14 @@ def mla_decode_fwd_impl(
         max_seqlen_qo,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )
 
 
@@ -87,6 +119,14 @@ def mla_decode_fwd_fake(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     pass
 
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index d935c02243bd..49e0e077206a 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import math
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional, Union
 
 import torch
 
@@ -56,6 +57,20 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # The query indptr, shape : [num_decode + 1]
     qo_indptr: torch.Tensor | None = None
 
+    max_seqlen_qo: int = 1
+
+    work_metadata: torch.Tensor | None = None
+
+    work_info_set: torch.Tensor | None = None
+
+    work_indptr: torch.Tensor | None = None
+
+    reduce_indptr: torch.Tensor | None = None
+
+    reduce_final_map: torch.Tensor | None = None
+
+    reduce_partial_map: torch.Tensor | None = None
+
 
 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
     pass
@@ -82,6 +97,10 @@ def __init__(
             "AITER MLA only supports block size 1."
         )
 
+        gpu = torch.cuda.current_device()
+        device_properties = torch.cuda.get_device_properties(gpu)
+        cu_num = device_properties.multi_processor_count
+
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
             vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
@@ -89,6 +108,36 @@
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
 
+        # num_mtp = vllm_config.speculative_config.num_speculative_tokens
+        # num_mtp = 1 if num_mtp is None else num_mtp
+        max_seqlen_qo = (
+            1
+            if vllm_config.speculative_config is None
+            else vllm_config.speculative_config.num_speculative_tokens
+        )
+
+        max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128))
+        self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
+        self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
+        self.work_info_set = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num, 8],
+            dtype=torch.int32,
+            device="cuda",
+        ).fill_(-1)
+        self.reduce_indptr = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch + 1],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.reduce_final_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
+        )
+        self.reduce_partial_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num],
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         # Preparing persistent buffers
         # TODO: we can disambiguate between decode and mixed-prefill decode here
         # so we can only use the persistent buffer if a cudagraph is actually
@@ -139,6 +188,32 @@ def _build_decode(
                 block_table_bounds.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        kv_indptr = torch.zeros(
+            [query_start_loc_cpu.size(0)], dtype=torch.int32, device="cuda"
+        )
+        torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_seqlen_qo = torch.max(query_lens).item()
+
+        import aiter
+
+        aiter.get_mla_metadata_v1(
+            query_start_loc_device,
+            kv_indptr,
+            self.num_heads // self.kv_cache_spec.num_kv_heads,
+            self.kv_cache_spec.num_kv_heads,
+            True,
+            self.work_metadata,
+            self.work_info_set,
+            self.work_indptr,
+            self.reduce_indptr,
+            self.reduce_final_map,
+            self.reduce_partial_map,
+            kv_granularity=max(self.kv_cache_spec.block_size, 16),
+            max_seqlen_qo=max_seqlen_qo,
+            uni_seqlen_qo=max_seqlen_qo,
+            fast_mode=True,
+        )
 
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
@@ -176,6 +251,13 @@ def _build_decode(
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_seqlen_qo=max_seqlen_qo,
+            work_metadata=self.work_metadata,
+            work_info_set=self.work_info_set,
+            work_indptr=self.work_indptr,
+            reduce_indptr=self.reduce_indptr,
+            reduce_final_map=self.reduce_final_map,
+            reduce_partial_map=self.reduce_partial_map,
         )
 
         return attn_metadata
@@ -256,24 +338,31 @@ def _forward_decode(
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
         o = torch.zeros(
-            B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
-        )
+            B, self.num_heads, self.kv_lora_rank, dtype=torch.bfloat16, device=q.device
+        ).fill_(-1)
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
         # max_seqlen_qo must be 1 except for MTP
         # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         aiter_mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_seqlen_qo,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
+            work_meta_data=attn_metadata.decode.work_metadata,
+            work_indptr=attn_metadata.decode.work_indptr,
+            work_info_set=attn_metadata.decode.work_info_set,
+            reduce_indptr=attn_metadata.decode.reduce_indptr,
+            reduce_final_map=attn_metadata.decode.reduce_final_map,
+            reduce_partial_map=attn_metadata.decode.reduce_partial_map,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
        )
 
         return o, None
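For reference, here is a minimal sketch of the sizing arithmetic behind the new persistent work/reduce buffers allocated in the builder's `__init__` above. The helper name, the example device values, and the standalone form are illustrative only (not part of the change); the shapes and the 128-wide qo tile divisor mirror what the diff allocates.

```python
import math


def mla_work_buffer_shapes(
    cu_num: int, max_num_reqs: int, max_seqlen_qo: int, num_heads: int
) -> dict[str, tuple[int, ...]]:
    """Shapes of the persistent AITER MLA scheduling buffers, as sized in __init__.

    `cu_num` is the GPU's multiprocessor (compute unit) count, `max_num_reqs` the
    scheduler's max_num_seqs, and `max_seqlen_qo` is 1 unless speculative decoding
    (MTP) raises the per-request query length.
    """
    # One qo tile covers up to 128 head*token rows, matching the divisor in the diff.
    max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * num_heads / 128))
    n_tiles = max_num_reqs * max_qo_tiles_per_batch
    return {
        "work_metadata": (10,),                  # uint64 header consumed by the kernel
        "work_indptr": (cu_num + 1,),            # per-CU work ranges
        "work_info_set": (n_tiles * cu_num, 8),  # int32, pre-filled with -1
        "reduce_indptr": (n_tiles + 1,),
        "reduce_final_map": (n_tiles, 2),
        "reduce_partial_map": (n_tiles * cu_num,),
    }


# Example values only: a 304-CU device, 256 max requests, MTP depth 2, 16 query heads.
print(mla_work_buffer_shapes(cu_num=304, max_num_reqs=256, max_seqlen_qo=2, num_heads=16))
```

These buffers are filled per step by `aiter.get_mla_metadata_v1` in `_build_decode` and then handed through `AiterMLADecodeMetadata` into `aiter_mla_decode_fwd`, which is why they are allocated once up front rather than per forward pass.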