Commit 1d14709

Cleanup ROCm output passing
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 9665313 commit 1d14709

File tree

1 file changed: +12 −14 lines changed


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 12 additions & 14 deletions
@@ -27,6 +27,7 @@


 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True

     @staticmethod
     def get_name() -> str:
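The new accept_output_buffer flag advertises that this backend can write attention results directly into a caller-provided tensor instead of allocating its own. A minimal caller-side sketch of the pattern this enables (FakeBackend and allocate_output are illustrative names, not vLLM's actual call sites):

import torch

class FakeBackend:
    # Mirrors the flag added above: the backend promises to fill a
    # pre-allocated buffer passed via an `output` argument.
    accept_output_buffer: bool = True

def allocate_output(backend_cls, query: torch.Tensor):
    if getattr(backend_cls, "accept_output_buffer", False):
        # Same shape/dtype/device as the query, as the asserts added
        # to _sdpa_attention below require.
        return torch.empty_like(query)
    return None  # backend allocates internally

output = allocate_output(FakeBackend, torch.empty(8, 32, 128))
assert output is not None and output.shape == (8, 32, 128)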
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens

-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -708,7 +710,7 @@ def forward(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -737,6 +739,7 @@ def forward(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -749,6 +752,7 @@ def forward(
                         q=query,
                         k=key,
                         v=value,
+                        out=output[:num_prefill_tokens],
                         cu_seqlens_q=query_seq_start_loc,
                         cu_seqlens_k=key_seq_start_loc,
                         max_seqlen_q=prefill_meta.max_prefill_seq_len,
@@ -762,10 +766,7 @@ def forward(

                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
+
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
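The deleted copy-back branch is unnecessary because the prefill kernels now write through a slice view of output, and a slice shares storage with its parent tensor. A toy demonstration of that in-place behavior (standalone, not vLLM code):

import torch

num_prefill_tokens = 3
output = torch.zeros(5, 2, 4)  # [num_tokens, num_heads, head_size]
prefill_out = output[:num_prefill_tokens]  # view, shares storage
prefill_out.copy_(torch.ones_like(prefill_out))  # kernel writes here
assert torch.equal(output[:3], torch.ones(3, 2, 4))   # parent updated
assert torch.equal(output[3:], torch.zeros(2, 2, 4))  # decode part untouched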
@@ -818,14 +819,10 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output

                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
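Similarly, the removed num_prefill_tokens > 0 check was redundant: slicing from index 0 yields a view of the entire buffer, so output[num_prefill_tokens:] covers the decode-only case (num_prefill_tokens == 0) for free. A quick check of that identity:

import torch

output = torch.arange(12).reshape(3, 4)
num_prefill_tokens = 0
decode_out = output[num_prefill_tokens:]  # full-buffer view, no copy
assert decode_out.data_ptr() == output.data_ptr()
assert torch.equal(decode_out, output)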
@@ -878,6 +875,7 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    output: torch.Tensor,
     seq_lens: List[int],
     num_tokens: int,
     num_heads: int,
@@ -886,9 +884,9 @@
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device

     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
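With the output parameter threaded through, _sdpa_attention validates the caller's buffer instead of allocating one, then writes each sequence's result into its slice. A simplified, self-contained sketch of that contract (it omits the attn_masks and scale handling of the real function; sdpa_into_buffer is an illustrative name):

import torch
from torch.nn.functional import scaled_dot_product_attention

def sdpa_into_buffer(query, key, value, output, seq_lens):
    # query/key/value/output: [num_tokens, num_heads, head_size]
    assert output.shape == query.shape
    assert output.dtype == query.dtype
    assert output.device == query.device
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        # SDPA expects [num_heads, seq_len, head_size], so transpose in
        # and out; copy_ writes the result into the caller's buffer.
        out = scaled_dot_product_attention(
            query[start:end].transpose(0, 1),
            key[start:end].transpose(0, 1),
            value[start:end].transpose(0, 1),
        )
        output[start:end].copy_(out.transpose(0, 1))
        start = end
    return output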
