
Commit 475d004

feat(keye): support xformers backend
Signed-off-by: Kwai-Keye <Keye@kuaishou.com>
1 parent 434da49 commit 475d004

1 file changed (+42 −22 lines)


vllm/model_executor/models/keye.py

Lines changed: 42 additions & 22 deletions
@@ -383,8 +383,9 @@ def __init__(
             prefix=f"{prefix}.out_proj",
         )

+        # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
-        if self.attn_backend not in {_Backend.FLASH_ATTN}:
+        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")

@@ -402,18 +403,22 @@ def forward(
             dim=-1,
         )

+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        batch_size = q.shape[0]
+
         if rope_emb is None:
-            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim).squeeze(0)
+            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
             k = k.view(
                 *k.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )
             v = v.view(
                 *v.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )
         else:
             if cu_seqlens is None:
                 raise ValueError(
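
The three new prologue lines derive everything both backends need from cu_seqlens, the cumulative boundaries of the per-image visual-token sequences packed along one axis: max_seqlen feeds flash-attn's varlen kernel and seqlens feeds xformers' block-diagonal mask. A minimal standalone sketch with made-up boundaries (not taken from the model) showing the values these expressions produce:

    import torch

    # Hypothetical packed batch of three images with 3, 5 and 2 visual tokens.
    cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 5
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()         # [3, 5, 2]

    assert max_seqlen == 5 and seqlens == [3, 5, 2]

Computing both values once, before the backend branch, lets the flash-attn and xformers paths share the same preprocessing.
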
@@ -426,31 +431,45 @@ def forward(
                 self.head_dim,
             )
             q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
-            q = q.squeeze(0)
-            k = k.squeeze(0)
             v = v.view(
                 *v.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )

-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            from flash_attn import flash_attn_varlen_func
+
+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
+            output = flash_attn_varlen_func(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
+                causal=False,
+                softmax_scale=self.scale,
+            )
+            context_layer = rearrange(output,
+                                      "(b s) ... -> b s ...",
+                                      b=batch_size)
+        elif self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

-        from flash_attn import flash_attn_varlen_func
-
-        output = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=max_seqlen,
-            max_seqlen_k=max_seqlen,
-            causal=False,
-            softmax_scale=self.scale,
-        )
+            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
+                                                       kv_seqlen=None,
+                                                       device=q.device)
+
+            context_layer = xops.memory_efficient_attention_forward(
+                q, k, v, attn_bias=attn_bias, p=0, scale=None)
+
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> b s (h d)").contiguous()

-        context_layer = output.flatten(-2).unsqueeze(0)
         output, _ = self.out_proj(context_layer)
         return output

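
The xformers branch replaces flash-attn's cu_seqlens packing with a BlockDiagonalMask, which restricts attention so that tokens of one image never attend to tokens of another even though all images share one packed sequence. Below is a conceptual, pure-PyTorch sketch of that behaviour (it is not the vLLM code and does not require xformers); the head counts and sequence lengths are invented for illustration:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seqlens = [3, 5, 2]                      # per-image visual-token counts
    total, heads, head_dim = sum(seqlens), 4, 8

    # Packed projections, batch size 1: [1, total_tokens, heads, head_dim].
    q = torch.randn(1, total, heads, head_dim)
    k = torch.randn(1, total, heads, head_dim)
    v = torch.randn(1, total, heads, head_dim)

    # Dense equivalent of the block-diagonal bias: True = may attend.
    mask = torch.block_diag(*[torch.ones(s, s, dtype=torch.bool) for s in seqlens])

    def sdpa(q, k, v, attn_mask=None):
        # scaled_dot_product_attention wants [batch, heads, tokens, head_dim].
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return out.transpose(1, 2)

    packed = sdpa(q, k, v, attn_mask=mask)

    # Reference: attend within each image separately, then re-concatenate.
    chunks, start = [], 0
    for s in seqlens:
        sl = slice(start, start + s)
        chunks.append(sdpa(q[:, sl], k[:, sl], v[:, sl]))
        start += s

    assert torch.allclose(packed, torch.cat(chunks, dim=1), atol=1e-5)

In the actual branch, BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, kv_seqlen=None) together with xops.memory_efficient_attention_forward achieves the same masking without materializing the dense block-diagonal matrix.
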

@@ -528,6 +547,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.layer_norm2(hidden_states)
         hidden_states = self.mlp(hidden_states)
+
         hidden_states = residual + hidden_states

         return hidden_states
