From 1d14709db871b4bf49294413930b7538ce444e62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Thu, 10 Apr 2025 20:34:57 +0000
Subject: [PATCH 1/4] Cleanup ROCm output passing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Luka Govedič
---
 vllm/attention/backends/rocm_flash_attn.py | 26 ++++++++++------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 7376f9303788..c3e6130585c0 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -27,6 +27,7 @@
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -708,7 +710,7 @@ def forward(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -737,6 +739,7 @@ def forward(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -749,6 +752,7 @@ def forward(
                         q=query,
                         k=key,
                         v=value,
+                        out=output[:num_prefill_tokens],
                         cu_seqlens_q=query_seq_start_loc,
                         cu_seqlens_k=key_seq_start_loc,
                         max_seqlen_q=prefill_meta.max_prefill_seq_len,
@@ -762,10 +766,7 @@ def forward(
 
                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
+
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
@@ -818,14 +819,10 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output
 
                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
@@ -878,6 +875,7 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    output: torch.Tensor,
     seq_lens: List[int],
     num_tokens: int,
     num_heads: int,
@@ -886,9 +884,9 @@ def _sdpa_attention(
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len

From 1db78f0c9d8f939b0a097b6db78e75fcd1c83dfa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Sun, 13 Apr 2025 06:19:33 +0000
Subject: [PATCH 2/4] Fix output for ROCm FA output
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Luka Govedič
---
 vllm/attention/backends/rocm_flash_attn.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c3e6130585c0..c7d69faf5906 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -516,7 +516,7 @@ def __init__(
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -532,7 +532,7 @@ def __init__(
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True
@@ -543,7 +543,7 @@ def __init__(
                         "ROCm Naive FlashAttention does not support "
                         "attention logits soft capping.")
 
-                self.attn_func = _sdpa_attention
+                self.sdpa_attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -706,7 +706,7 @@ def forward(
                             query.dtype,
                             seq_lens,
                             make_attn_mask=causal_mask)  # type: ignore
-                    out, _ = self.attn_func(
+                    self.triton_attn_func(
                         query,
                         key,
                         value,
@@ -735,7 +735,7 @@ def forward(
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
@@ -748,7 +748,8 @@ def forward(
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -764,9 +765,6 @@ def forward(
                         softcap=self.logits_soft_cap,
                     )
 
-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models

From 38f95b1a27a459c10fcefb0cab3d7a07fcb6f92f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Sun, 13 Apr 2025 20:26:37 +0000
Subject: [PATCH 3/4] Fix sdpa arg type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Luka Govedič
---
 vllm/attention/backends/rocm_flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c7d69faf5906..c01985abbcfa 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -874,7 +874,7 @@ def _sdpa_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    seq_lens: List[int],
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,

From f0b63eed85d815b2caa2bed01027b0bc48860598 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Sun, 13 Apr 2025 20:30:13 +0000
Subject: [PATCH 4/4] Remove out param from FA path
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Luka Govedič
---
 vllm/attention/backends/rocm_flash_attn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c01985abbcfa..90a21906b6e6 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -753,7 +753,6 @@ def forward(
                         q=query,
                         k=key,
                         v=value,
-                        out=output[:num_prefill_tokens],
                         cu_seqlens_q=query_seq_start_loc,
                         cu_seqlens_k=key_seq_start_loc,
                         max_seqlen_q=prefill_meta.max_prefill_seq_len,