diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 363aa08ef003..2ef66229b833 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1213,9 +1213,9 @@ def _compute_prefill_context(
 
             attn_output, attn_softmax_lse = \
                 self._flash_attn_varlen_diff_headdims(
-                    q=q,
-                    k=k,
-                    v=v,
+                    q,
+                    k,
+                    v,
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
@@ -1267,9 +1267,9 @@ def _forward_prefill(
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
         output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
+            q,
+            k,
+            v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py
index 4936c8201399..d6df6d8b0e8e 100644
--- a/vllm/attention/backends/rocm_aiter_mla.py
+++ b/vllm/attention/backends/rocm_aiter_mla.py
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
 
 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None
 
+    # This is just to make new AITER MLA API work
+    # -- MTP support is not added yet.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]
 
     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -208,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)
 
         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
@@ -226,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
 
         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
@@ -245,16 +255,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                                                    1,
                                                    device=device,
                                                    dtype=torch.int)
+
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None
 
         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -263,14 +279,17 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
 
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
-            max_batch_size=max_batch_size,
-            block_size=self.runner.block_size,
-            max_block_per_batch=self.runner.get_max_block_per_batch(),
-            device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
+            get_aiter_mla_metadata(
+                max_batch_size=max_batch_size,
+                block_size=self.runner.block_size,
+                max_block_per_batch=\
+                    self.runner.get_max_block_per_batch(),
+                device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr
 
         with super().graph_capture(max_batch_size):
             yield
@@ -278,6 +297,7 @@ def graph_capture(self, max_batch_size: int):
         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor
 
     def graph_capture_get_metadata_for_batch(
             self,
@@ -291,10 +311,12 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
 
         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -311,6 +333,7 @@ def get_graph_input_buffers(self,
             input_buffers[
                 "paged_kv_last_page_lens"] = attn_metadata.\
                     decode_metadata.paged_kv_last_page_lens
+            input_buffers['qo_indptr'] = attn_metadata.qo_indptr
 
         return input_buffers
@@ -330,6 +353,8 @@ def prepare_graph_input_buffers(self,
             input_buffers["paged_kv_last_page_lens"].copy_(
                 attn_metadata.decode_metadata.paged_kv_last_page_lens,
                 non_blocking=True)
+            input_buffers["qo_indptr"].copy_(
+                attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
 
 
 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -370,11 +395,9 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float,
             return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            q,
+            k,
+            v,
             **kwargs,
         )
 
@@ -394,7 +417,7 @@ def _forward_decode(
         B = q_nope.shape[0]
 
         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
@@ -402,9 +425,14 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+        aiter_mla_decode_fwd(q,
+                             kv_buffer,
+                             o,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
-                             attn_metadata.paged_kv_last_page_lens)
+                             attn_metadata.paged_kv_last_page_lens,
+                             sm_scale=self.scale)
 
         return self._v_up_proj(o)
diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 3348d18804aa..b41db744feeb 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -20,17 +20,20 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
     paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                          block_size,
                                          dtype=torch.int32)
-    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
+    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
+    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
 
 
 def aiter_mla_decode_fwd(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
-    sm_scale: float,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
     logit_cap: float = 0.0,
 ):
 
@@ -38,6 +41,8 @@ def aiter_mla_decode_fwd(
         kv_buffer.view(
             -1, 1, 1, q.shape[-1]),
         o,
+        qo_indptr,
+        max_seqlen_qo,
         kv_indptr,
         kv_indices,
         kv_last_page_lens,
@@ -49,6 +54,8 @@ def mla_decode_fwd_impl(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -60,9 +67,11 @@ def mla_decode_fwd_impl(
     mla_decode_fwd(q,
                    kv_buffer.view(-1, 1, 1, q.shape[-1]),
                    o,
+                   qo_indptr,
                    kv_indptr,
                    kv_indices,
                    kv_last_page_lens,
+                   max_seqlen_qo,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
 
@@ -71,6 +80,8 @@ def mla_decode_fwd_fake(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 7d7bce9ec6ab..b31af95248e3 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
 
     fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                              sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk, w1_scale.view(local_E, -1),
-                             w2_scale.view(local_E, -1),
-                             a1_scale.t().contiguous(), *block_shape,
-                             smooth_scale)
+                             num_valid_ids, topk,
+                             a1_scale.t().contiguous(),
+                             w1_scale.view(local_E, -1),
+                             w2_scale.view(local_E,
+                                           -1), *block_shape, smooth_scale)
 
     return out_asm
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 37b72c08d52b..81bac9a15396 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -186,9 +186,14 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+        aiter_mla_decode_fwd(q,
+                             kv_buffer,
+                             o,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.decode.paged_kv_indptr,
                              attn_metadata.decode.paged_kv_indices,
-                             attn_metadata.decode.paged_kv_last_page_len)
+                             attn_metadata.decode.paged_kv_last_page_len,
+                             sm_scale=self.scale)
 
         return self._v_up_proj(o)