@@ -383,8 +383,9 @@ def __init__(
             prefix=f"{prefix}.out_proj",
         )

+        # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
-        if self.attn_backend not in {_Backend.FLASH_ATTN}:
+        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")

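For context, a minimal, self-contained sketch of the guard pattern this hunk extends: detect the attention backend once at construction time and fail fast when it is not in the supported set. The enum, the detection helper, and the fallback order below are hypothetical stand-ins for illustration only, not vLLM's actual _Backend or get_vit_attn_backend.

from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


def detect_backend() -> Backend:
    # Stand-in for get_vit_attn_backend(support_fa=True): prefer FlashAttention
    # when the package is importable, otherwise fall back to xFormers.
    try:
        import flash_attn  # noqa: F401
        return Backend.FLASH_ATTN
    except ImportError:
        return Backend.XFORMERS


SUPPORTED_BACKENDS = {Backend.FLASH_ATTN, Backend.XFORMERS}

attn_backend = detect_backend()
if attn_backend not in SUPPORTED_BACKENDS:
    raise RuntimeError(f"Keye-VL does not support {attn_backend} backend now.")
print(f"Using {attn_backend.name} for the vision tower")
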
@@ -402,18 +403,22 @@ def forward(
             dim=-1,
         )

+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        batch_size = q.shape[0]
+
         if rope_emb is None:
-            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim).squeeze(0)
+            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
             k = k.view(
                 *k.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )
             v = v.view(
                 *v.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )
         else:
             if cu_seqlens is None:
                 raise ValueError(
@@ -426,31 +431,45 @@ def forward(
                 self.head_dim,
             )
             q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
-            q = q.squeeze(0)
-            k = k.squeeze(0)
             v = v.view(
                 *v.shape[:-1],
                 self.num_kv_heads,
                 self.head_dim,
-            ).squeeze(0)
+            )

-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            from flash_attn import flash_attn_varlen_func
+
+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
+            output = flash_attn_varlen_func(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
+                causal=False,
+                softmax_scale=self.scale,
+            )
+            context_layer = rearrange(output,
+                                      "(b s) ... -> b s ...",
+                                      b=batch_size)
+        elif self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

-        from flash_attn import flash_attn_varlen_func
-
-        output = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=max_seqlen,
-            max_seqlen_k=max_seqlen,
-            causal=False,
-            softmax_scale=self.scale,
-        )
+            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
+                                                       kv_seqlen=None,
+                                                       device=q.device)
+
+            context_layer = xops.memory_efficient_attention_forward(
+                q, k, v, attn_bias=attn_bias, p=0, scale=None)
+
+        context_layer = rearrange(context_layer,
+                                  "b s h d -> b s (h d)").contiguous()

-        context_layer = output.flatten(-2).unsqueeze(0)
         output, _ = self.out_proj(context_layer)
         return output

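Both branches above compute the same thing: attention over a packed batch of variable-length vision sequences in which each token may only attend to tokens from its own image, i.e. a block-diagonal mask derived from cu_seqlens. Below is a torch-only sketch of that semantics with made-up shapes and sequence lengths; it is an illustration, not the vLLM implementation, and uses neither backend.

# Block-diagonal (per-image) attention over a packed sequence, torch only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_heads, head_dim = 4, 32
seqlens = [5, 3, 7]                       # per-image token counts (illustrative)
total = sum(seqlens)
cu_seqlens = torch.tensor([0, 5, 8, 15])  # cumulative boundaries, like cu_seqlens above

# Packed q/k/v with batch_size=1, shape (1, total_tokens, heads, head_dim).
q = torch.randn(1, total, num_heads, head_dim)
k = torch.randn(1, total, num_heads, head_dim)
v = torch.randn(1, total, num_heads, head_dim)

# Build the block-diagonal mask: True where attention is allowed.
mask = torch.zeros(total, total, dtype=torch.bool)
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    mask[start:end, start:end] = True

# scaled_dot_product_attention expects (batch, heads, seq, head_dim).
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    attn_mask=mask,
)
out = out.transpose(1, 2).reshape(1, total, num_heads * head_dim)
print(out.shape)  # torch.Size([1, 15, 128])

In the diff, the FLASH_ATTN path reaches the same result by flattening tokens and passing cu_seqlens to flash_attn_varlen_func, while the XFORMERS path expresses the same restriction through BlockDiagonalMask.from_seqlens.
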
@@ -528,6 +547,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.layer_norm2(hidden_states)
         hidden_states = self.mlp(hidden_states)
+
         hidden_states = residual + hidden_states

         return hidden_states