From 31bf99752a8bab33726c68bb9c459ed09b47423a Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 14 Feb 2025 22:29:02 +0000
Subject: [PATCH 1/3] Massage MLA's usage of flash attn for ROCm

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/backends/mla/utils.py | 36 ++++++++++++++++-----------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index e9b4dff74f42..314c5a19eed7 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -487,20 +487,26 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=seq_start_loc,
-            cu_seqlens_k=seq_start_loc,
-            max_seqlen_q=max_prefill_seq_len,
-            max_seqlen_k=max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            fa_version=self.vllm_flash_attn_version,
-        )
-        attn_output = attn_output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
+        # Here we massage the flash_attn_varlen_func interface since we use
+        # this codepath via vllm-flash-attn on NVIDIA and the upstream branch
+        # on AMD.
+        fa_args = {
+            "q": q,
+            "k": k,
+            "v": v_padded,
+            "cu_seqlens_q": seq_start_loc,
+            "cu_seqlens_k": seq_start_loc,
+            "max_seqlen_q": max_prefill_seq_len,
+            "max_seqlen_k": max_prefill_seq_len,
+            "softmax_scale": self.scale,
+            "causal": True,
+        }
+
+        if self.vllm_flash_attn_version is not None:
+            fa_args["fa_version"] = self.vllm_flash_attn_version
+
+        attn_output = flash_attn_varlen_func(**fa_args).view(
+            -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]].reshape(
+                -1, self.num_heads * v.shape[-1])
 
         return self.o_proj(attn_output)[0]

From 460783f247ad508ddaefb5f786249f1cb7181b7b Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 14 Feb 2025 22:35:28 +0000
Subject: [PATCH 2/3] formatting

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/backends/mla/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index 314c5a19eed7..b37ceabd949f 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -505,8 +505,8 @@ def _forward_prefill_flash(
         if self.vllm_flash_attn_version is not None:
             fa_args["fa_version"] = self.vllm_flash_attn_version
 
-        attn_output = flash_attn_varlen_func(**fa_args).view(
-            -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]].reshape(
-                -1, self.num_heads * v.shape[-1])
+        attn_output = flash_attn_varlen_func(**fa_args)\
+            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
+            .reshape(-1, self.num_heads * v.shape[-1])
 
         return self.o_proj(attn_output)[0]

From 645a911841c0243f27d04bcafee51aaaaebfa2f3 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 14 Feb 2025 22:53:54 +0000
Subject: [PATCH 3/3] Improve

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/backends/mla/utils.py | 45 +++++++++++++++-------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index b37ceabd949f..df3fb2aeefc4 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ def __init__(
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
 
+        # Handle the differences between the flash_attn_varlen_func from
+        # flash_attn and the one from vllm_flash_attn. The former is used on
+        # ROCm and the latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
@@ -487,26 +497,19 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        # Here we massage the flash_attn_varlen_func interface since we use
-        # this codepath via vllm-flash-attn on NVIDIA and the upstream branch
-        # on AMD.
-        fa_args = {
-            "q": q,
-            "k": k,
-            "v": v_padded,
-            "cu_seqlens_q": seq_start_loc,
-            "cu_seqlens_k": seq_start_loc,
-            "max_seqlen_q": max_prefill_seq_len,
-            "max_seqlen_k": max_prefill_seq_len,
-            "softmax_scale": self.scale,
-            "causal": True,
-        }
-
-        if self.vllm_flash_attn_version is not None:
-            fa_args["fa_version"] = self.vllm_flash_attn_version
-
-        attn_output = flash_attn_varlen_func(**fa_args)\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
+        attn_output = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v_padded,
+            cu_seqlens_q=seq_start_loc,
+            cu_seqlens_k=seq_start_loc,
+            max_seqlen_q=max_prefill_seq_len,
+            max_seqlen_k=max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+        )
+        attn_output = attn_output\
+            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
+            .reshape(-1, self.num_heads * v.shape[-1])
 
         return self.o_proj(attn_output)[0]
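
Note on the dispatch pattern in PATCH 3/3: the shim binds the extra fa_version keyword once in __init__ via functools.partial, so _forward_prefill_flash calls self.flash_attn_varlen_func identically whether the vllm-flash-attn fork (NVIDIA) or upstream flash-attn (ROCm) is installed. The sketch below is a minimal standalone illustration of that pattern only; upstream_varlen_attn, vllm_varlen_attn, and make_varlen_func are hypothetical stand-ins, not the real library entry points.

import functools
from typing import Callable, Optional


def upstream_varlen_attn(q, k, v, softmax_scale, causal):
    # Stand-in for upstream flash_attn's varlen entry point: no fa_version kwarg.
    return ("upstream-flash-attn", softmax_scale, causal)


def vllm_varlen_attn(q, k, v, softmax_scale, causal, fa_version):
    # Stand-in for vllm-flash-attn's varlen entry point: extra fa_version kwarg
    # selects FA2 vs FA3.
    return ("vllm-flash-attn", fa_version, softmax_scale, causal)


def make_varlen_func(fa_version: Optional[int]) -> Callable:
    # Mirror of the __init__ logic in PATCH 3/3: bind fa_version once when the
    # vLLM fork is in use, otherwise fall back to the upstream signature.
    if fa_version is not None:
        return functools.partial(vllm_varlen_attn, fa_version=fa_version)
    return upstream_varlen_attn


if __name__ == "__main__":
    # Call sites look identical on both backends.
    attn = make_varlen_func(fa_version=3)     # e.g. NVIDIA + vllm-flash-attn
    print(attn("q", "k", "v", softmax_scale=0.125, causal=True))
    attn = make_varlen_func(fa_version=None)  # e.g. ROCm + upstream flash-attn
    print(attn("q", "k", "v", softmax_scale=0.125, causal=True))

Binding the backend once at construction time is what lets PATCH 3/3 drop the per-call fa_args dict introduced in PATCH 1/3.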