From 4f19e55795462315ed748575a59a281a9bfbf868 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 20 Jun 2025 01:56:57 +0000 Subject: [PATCH 1/3] fix AITER Flash Attention for Llama4 Signed-off-by: tjtanaa --- vllm/attention/layer.py | 14 +++++--- vllm/v1/attention/backends/rocm_aiter_fa.py | 39 ++++++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f7d230c5d7d6..a9ceeae363ac 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -306,12 +306,16 @@ def __init__( block_size=16, is_attention_free=False) backend = backend_name_to_enum(attn_backend.get_name()) - if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: - backend = _Backend.XFORMERS + if current_platform.is_rocm(): + # currently, only torch_sdpa is supported on rocm + backend = _Backend.TORCH_SDPA + else: + if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + backend = _Backend.XFORMERS - self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 - } else _Backend.TORCH_SDPA + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + } else _Backend.TORCH_SDPA def forward( self, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index e011e95efd41..31eae7928b82 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, self.runner.device, non_blocking=True) local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( self.runner.device, non_blocking=True) - local_max_query_len = seqlens_q_local_np.max() - local_max_seq_len = virt_k_seqlens_np.max() + local_max_query_len = int(seqlens_q_local_np.max()) + local_max_seq_len = int(virt_k_seqlens_np.max()) local_scheduler_metadata = schedule( batch_size=local_query_start_loc.shape[0] - 1, cu_query_lens=local_query_start_loc, @@ -387,6 +387,7 @@ def __init__( blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: if blocksparse_params is not None: @@ -408,6 +409,7 @@ def __init__( # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0. self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -478,22 +480,25 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - # Reshape the input keys and values and store them in the cache. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] and - # value[:num_actual_tokens] because the reshape_and_cache_flash op uses - # the slot_mapping's shape to determine the number of actual tokens. 
key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(torch.float8_e4m3fnuz) From f4de631190691a92a33fa88fd30664982155129f Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 20 Jun 2025 10:10:37 +0000 Subject: [PATCH 2/3] fix MHA backend selection on rocm platform Signed-off-by: tjtanaa --- vllm/attention/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a9ceeae363ac..0c79aaf13551 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -308,7 +308,7 @@ def __init__( backend = backend_name_to_enum(attn_backend.get_name()) if current_platform.is_rocm(): # currently, only torch_sdpa is supported on rocm - backend = _Backend.TORCH_SDPA + self.attn_backend = _Backend.TORCH_SDPA else: if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: backend = _Backend.XFORMERS From 2dda1d9d84aee9f708a501ffcaf2bdbbdb41c148 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sun, 22 Jun 2025 13:27:03 +0000 Subject: [PATCH 3/3] fix local attention metadata Signed-off-by: tjtanaa --- vllm/v1/attention/backends/rocm_aiter_fa.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 31eae7928b82..dc8ff2261306 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -253,6 +253,17 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len=local_max_seq_len, causal=True) + local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, + dtype=torch.int32, + device=self.runner.device) + local_cu_seq_lens[1:] = torch.cumsum( + torch.from_numpy(virt_k_seqlens_np).to( + device=self.runner.device, + dtype=torch.int32, + non_blocking=True), + dim=0) + + local_attn_metadata = \ AiterFlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, @@ -260,6 +271,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, + local_cu_seq_lens=local_cu_seq_lens, local_scheduler_metadata=local_scheduler_metadata, ) @@ -368,6 +380,7 @@ class LocalAttentionMetadata: local_block_table: torch.Tensor local_max_query_len: int local_max_seq_len: int + local_cu_seq_lens: torch.Tensor local_scheduler_metadata: Optional[torch.Tensor] local_attn_metadata: Optional[LocalAttentionMetadata] = None @@ -546,7 +559,8 @@ def forward( alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, 
block_table=block_table, - cu_seqlens_k=cu_seq_lens, + cu_seqlens_k=(cu_seq_lens if not use_local_attn else + local_metadata.local_cu_seq_lens), ) _, num_heads, head_size = query.shape
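
Note on the final hunk: when use_local_attn is set, cu_seqlens_k now comes from local_cu_seq_lens, the per-chunk cumulative key lengths built alongside the rest of the local-attention metadata, so the key offsets handed to the varlen kernel line up with the per-chunk local_query_start_loc and local_seqused_k it already receives. The standalone sketch below shows how such a cumulative-length tensor is derived from per-chunk key lengths; the array values, device, and print are illustrative stand-ins, not code from this PR.

    import numpy as np
    import torch

    # Per-chunk key lengths, standing in for virt_k_seqlens_np in the patch
    # (one entry per local-attention virtual batch).
    virt_k_seqlens_np = np.array([3, 5, 2], dtype=np.int32)

    device = "cpu"  # the patch uses self.runner.device (a ROCm GPU)

    # cu_seqlens has one more entry than there are chunks; entry i holds the
    # total number of keys in chunks [0, i), so entry 0 stays 0.
    local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
    local_cu_seq_lens[1:] = torch.cumsum(
        torch.from_numpy(virt_k_seqlens_np).to(device=device,
                                               dtype=torch.int32),
        dim=0)

    print(local_cu_seq_lens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

Passing the global cu_seq_lens here instead would give key offsets that do not match the per-chunk query offsets, which is the mismatch the third patch addresses.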