[Misc][Bugfix] FA3 support to ViT MHA layer (vllm-project#12435)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
ywang96 and Isotr0py committed Feb 2, 2025
1 parent cd3e0e0 commit 92159d7
Showing 1 changed file with 22 additions and 3 deletions.
vllm/attention/layer.py: 22 additions & 3 deletions
@@ -251,9 +251,28 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops

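As an aside, the bookkeeping behind the new flash_attn_varlen_func call can be illustrated with plain PyTorch. The sketch below is not part of the diff and the toy sizes are made up; it only shows how the torch.arange-built cu_seqlens tensors describe a batch of equal-length sequences once the (bsz, seq_len, ...) tensors are flattened:

# Illustrative sketch only -- plain PyTorch, no GPU or vllm_flash_attn needed.
# With equal-length sequences, the cumulative sequence lengths are simply
# multiples of the per-sequence length, so the flattened tensor is carved
# into `bsz` contiguous chunks.
import torch

bsz, q_len, num_heads, head_size = 2, 3, 4, 8   # toy sizes (assumed)
query = torch.randn(bsz, q_len, num_heads, head_size)

cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_seqlens_q)              # tensor([0, 3, 6], dtype=torch.int32)

flat = query.flatten(0, 1)       # (bsz * q_len, num_heads, head_size)
# Sequence i occupies rows cu_seqlens_q[i]:cu_seqlens_q[i + 1] of the flat
# tensor, which is the layout a varlen attention kernel expects.
assert torch.equal(flat[cu_seqlens_q[0]:cu_seqlens_q[1]], query[0])

Because every sequence in the ViT batch has the same length, the arange trick with a fixed step is enough; no per-sample lengths need to be tracked, and the final out.reshape(bsz, q_len, -1) in the diff restores the batched layout expected by the caller.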
