@@ -27,6 +27,7 @@
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
@@ -515,7 +516,7 @@ def __init__(
 
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ def __init__(
         else:
             try:
                 from flash_attn import flash_attn_varlen_func  # noqa: F401
-                self.attn_func = flash_attn_varlen_func
+                self.fa_attn_func = flash_attn_varlen_func
                 logger.debug("Using CK FA in ROCmBackend")
             except ModuleNotFoundError:
                 self.use_naive_attn = True
@@ -542,7 +543,7 @@ def __init__(
                     "ROCm Naive FlashAttention does not support "
                     "attention logits soft capping.")
 
-            self.attn_func = _sdpa_attention
+            self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
-                out, _ = self.attn_func(
+                self.triton_attn_func(
                     query,
                     key,
                     value,
-                    None,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     key_seq_start_loc,
                     query_max_seq_len,
@@ -733,10 +735,11 @@ def forward(
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
                 # sdpa math backend attention
-                out = self.attn_func(
+                self.sdpa_attn_func(
                     query,
                     key,
                     value,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     num_prefill_tokens,
                     self.num_heads,
@@ -745,7 +748,8 @@ def forward(
                     attn_masks,
                 )
             else:
-                out = self.attn_func(
+                # upstream FA does not support an output arg, copy
+                output[:num_prefill_tokens] = self.fa_attn_func(
                     q=query,
                     k=key,
                     v=value,
@@ -760,12 +764,6 @@ def forward(
                     softcap=self.logits_soft_cap,
                 )
 
-            # common code for prefill
-            assert output[:num_prefill_tokens].shape == out.shape
-            if output.shape[0] > num_prefill_tokens:
-                output[:num_prefill_tokens] = out
-            else:
-                output = out
         else:
             # prefix-enabled attention -
             # not applicable for encoder-only models
@@ -818,14 +816,10 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output
 
                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
@@ -878,17 +872,18 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     scale: float,
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len