2727
2828
2929class ROCmFlashAttentionBackend (AttentionBackend ):
30+ accept_output_buffer : bool = True
3031
3132 @staticmethod
3233 def get_name () -> str :
@@ -613,6 +614,8 @@ def forward(
613614 Returns:
614615 shape = [num_tokens, num_heads * head_size]
615616 """
617+ assert output is not None , "Output tensor must be provided."
618+
616619 query = query .view (- 1 , self .num_heads , self .head_size )
617620 if key is not None :
618621 assert value is not None
@@ -656,7 +659,6 @@ def forward(
656659 assert attn_metadata .num_encoder_tokens is not None
657660 num_prefill_tokens = attn_metadata .num_encoder_tokens
658661
659- output = torch .empty_like (query )
660662 # Query for decode. KV is not needed because it is already cached.
661663 decode_query = query [num_prefill_tokens :]
662664 # QKV for prefill.
@@ -708,7 +710,7 @@ def forward(
708710 query ,
709711 key ,
710712 value ,
711- None ,
713+ output [: num_prefill_tokens ] ,
712714 query_seq_start_loc ,
713715 key_seq_start_loc ,
714716 query_max_seq_len ,
@@ -737,6 +739,7 @@ def forward(
737739 query ,
738740 key ,
739741 value ,
742+ output [:num_prefill_tokens ],
740743 query_seq_start_loc ,
741744 num_prefill_tokens ,
742745 self .num_heads ,
@@ -749,6 +752,7 @@ def forward(
749752 q = query ,
750753 k = key ,
751754 v = value ,
755+ out = output [:num_prefill_tokens ],
752756 cu_seqlens_q = query_seq_start_loc ,
753757 cu_seqlens_k = key_seq_start_loc ,
754758 max_seqlen_q = prefill_meta .max_prefill_seq_len ,
@@ -762,10 +766,7 @@ def forward(
762766
763767 # common code for prefill
764768 assert output [:num_prefill_tokens ].shape == out .shape
765- if output .shape [0 ] > num_prefill_tokens :
766- output [:num_prefill_tokens ] = out
767- else :
768- output = out
769+
769770 else :
770771 # prefix-enabled attention -
771772 # not applicable for encoder-only models
@@ -818,14 +819,10 @@ def forward(
818819 device = output .device ,
819820 )
820821 max_logits = torch .empty_like (exp_sums )
821- if num_prefill_tokens > 0 :
822- out = output [num_prefill_tokens :]
823- else :
824- out = output
825822
826823 query_start_loc = None
827824 ops .paged_attention_rocm (
828- out ,
825+ output [ num_prefill_tokens :] ,
829826 exp_sums ,
830827 max_logits ,
831828 tmp_output ,
@@ -878,6 +875,7 @@ def _sdpa_attention(
878875 query : torch .Tensor ,
879876 key : torch .Tensor ,
880877 value : torch .Tensor ,
878+ output : torch .Tensor ,
881879 seq_lens : List [int ],
882880 num_tokens : int ,
883881 num_heads : int ,
@@ -886,9 +884,9 @@ def _sdpa_attention(
886884 attn_masks : Optional [List [torch .Tensor ]] = None ,
887885) -> torch .Tensor :
888886 start = 0
889- output = torch . empty (( num_tokens , num_heads , head_size ),
890- dtype = query .dtype ,
891- device = query .device )
887+ assert output . shape == ( num_tokens , num_heads , head_size )
888+ assert output . dtype == query .dtype
889+ assert output . device == query .device
892890
893891 for i , seq_len in enumerate (seq_lens ):
894892 end = start + seq_len
0 commit comments