41 changes: 18 additions & 23 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -27,6 +27,7 @@


 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True

     @staticmethod
     def get_name() -> str:
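Note on `accept_output_buffer`: this flag tells the attention layer that the backend can write its result into a caller-provided tensor, so the output is allocated once by the layer rather than by the backend on every forward. A minimal caller-side sketch of the pattern this enables; the `backend_impl` name and wiring are hypothetical, not vLLM's actual call site:

```python
import torch

# Hypothetical caller: the layer pre-allocates the result buffer and hands it
# to the backend, which fills it in place instead of returning a fresh tensor.
num_tokens, num_heads, head_size = 16, 8, 64
query = torch.randn(num_tokens, num_heads * head_size)

output = torch.empty_like(query)  # allocated by the layer, not the backend
# backend_impl.forward(query, key, value, kv_cache, attn_metadata, output=output)
# `output` would then hold the attention result with no extra copy.
```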
@@ -515,7 +516,7 @@ def __init__(

             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ def __init__(
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True
@@ -542,7 +543,7 @@ def __init__(
                         "ROCm Naive FlashAttention does not support "
                         "attention logits soft capping.")

-                self.attn_func = _sdpa_attention
+                self.sdpa_attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")

     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens

-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
-                out, _ = self.attn_func(
+                self.triton_attn_func(
                     query,
                     key,
                     value,
-                    None,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     key_seq_start_loc,
                     query_max_seq_len,
@@ -733,10 +735,11 @@ def forward(
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
                 # sdpa math backend attention
-                out = self.attn_func(
+                self.sdpa_attn_func(
                     query,
                     key,
                     value,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     num_prefill_tokens,
                     self.num_heads,
@@ -745,7 +748,8 @@ def forward(
                     attn_masks,
                 )
             else:
-                out = self.attn_func(
+                # upstream FA does not support an output arg, copy
+                output[:num_prefill_tokens] = self.fa_attn_func(
                     q=query,
                     k=key,
                     v=value,
@@ -760,12 +764,6 @@ def forward(
                     softcap=self.logits_soft_cap,
                 )

-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
         else:
             # prefix-enabled attention -
             # not applicable for encoder-only models
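Note on the three prefill paths above: the Triton and SDPA functions receive `output[:num_prefill_tokens]` and write through it, while upstream `flash_attn_varlen_func` has no output argument, so its return value is copied in with a slice assignment. The PyTorch semantics that make both work (illustrative snippet, not vLLM code):

```python
import torch

output = torch.zeros(6, 2)
prefill_result = torch.ones(4, 2)

# A slice is a view sharing storage, so a callee writing through it
# mutates `output` directly (the Triton/SDPA paths).
view = output[:4]
view.fill_(2.0)
assert output[0, 0].item() == 2.0

# Slice assignment copies into the existing buffer in place (the CK FA path).
output[:4] = prefill_result
assert torch.equal(output[:4], prefill_result)

# Rebinding a name (the removed `output = out` pattern) does neither; it only
# points the local name at a different tensor, hence the old bookkeeping.
```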
@@ -818,14 +816,10 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output

                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
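Note on the dropped `num_prefill_tokens > 0` branch: slicing from index 0 yields a view of the whole tensor, so `output[num_prefill_tokens:]` is already correct for a pure-decode batch. A quick illustrative check, not vLLM code:

```python
import torch

output = torch.zeros(8, 4)
num_prefill_tokens = 0  # pure-decode batch

tail = output[num_prefill_tokens:]  # full-tensor view, shares storage
assert tail.data_ptr() == output.data_ptr()

tail.fill_(1.0)  # the decode kernel's writes land in `output` itself
assert bool((output == 1.0).all())
```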
@@ -878,17 +872,18 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     scale: float,
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device

     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
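The hunk cuts off at the top of the per-sequence loop. For reference, a minimal sketch of how such a loop can fill the provided buffer using `torch.nn.functional.scaled_dot_product_attention`; vLLM's actual backend selection and mask handling may differ:

```python
import torch
import torch.nn.functional as F

def sdpa_into_buffer(query, key, value, output, seq_lens, scale,
                     attn_masks=None):
    """Varlen SDPA sketch; tensors are [num_tokens, num_heads, head_size]."""
    start = 0
    for i, seq_len in enumerate(seq_lens):
        end = start + int(seq_len)
        mask = attn_masks[i] if attn_masks is not None else None
        # SDPA wants [..., seq_len, head_size]; use heads as the batch dim.
        out = F.scaled_dot_product_attention(
            query[start:end].movedim(0, 1),
            key[start:end].movedim(0, 1),
            value[start:end].movedim(0, 1),
            attn_mask=mask,
            is_causal=mask is None,
            scale=scale,
        )
        output[start:end] = out.movedim(0, 1)  # write into caller's buffer
        start = end
    return output
```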