72 changes: 53 additions & 19 deletions tests/kernels/attention.py
@@ -1,5 +1,5 @@
import random
from typing import Optional
from typing import List, Optional

from flash_attn.flash_attention import FlashAttention
import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
        output[i].copy_(out, non_blocking=True)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    head_size = query.shape[-1]
    scale = 1.0 / (head_size ** 0.5)

    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask
        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
        attn_mask = attn_mask.to(dtype=dtype, device='cuda')

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


def test_single_query_cached_kv_attention(
    num_tokens: int,
    num_heads: int,
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
        causal=True,
    )[0]

    ref_outputs = []
    for i, seq_len in enumerate(seq_lens):
        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)

    cu_seq_lens = cu_seq_lens.cpu().tolist()
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


@torch.inference_mode()
def test_attention() -> None:
def test_attention(seed: int) -> None:
    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
    # the test fails due to the precision issue. Re-run the test if it fails.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    for dtype in [torch.half, torch.float]:
        for block_size in [8, 16]:
            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                print(f'Testing single_query_cached_kv_attention with '
                      f'dtype={dtype}, block_size={block_size}, '
                      f'head_size={head_size}')
                test_single_query_cached_kv_attention(
                    num_tokens=37,
                    num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
    for dtype in [torch.half]:
        # NOTE(woosuk): FlashAttention does not support head_size > 128.
        for head_size in [64, 80, 96, 128]:
            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                  f'head_size={head_size}')
            test_multi_query_kv_attention(
                num_seqs=11,
                num_heads=3,
@@ -202,4 +236,4 @@


if __name__ == '__main__':
    test_attention()
    test_attention(seed=0)
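
Below is a minimal, illustrative sketch of how the new ref_multi_query_kv_attention helper is driven: cu_seq_lens is the prefix sum of the per-sequence lengths, so entries i and i + 1 bracket sequence i inside the packed query/key/value tensors. The sequence lengths, head count, head size, and packed layout used here are assumptions made for illustration, not values taken from the test itself.

import torch

# Illustrative sizes only; the real test draws random sequence lengths.
seq_lens = [2, 3, 1]
cu_seq_lens = [0]
for seq_len in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
# cu_seq_lens == [0, 2, 5, 6]; sequence i occupies rows
# cu_seq_lens[i]:cu_seq_lens[i + 1] of the packed tensors.

num_tokens = cu_seq_lens[-1]
num_heads = 3     # assumed, for illustration
head_size = 64    # assumed, for illustration
dtype = torch.half

# Assumed packed layout of (num_tokens, num_heads, head_size); the helper
# only requires that dim 0 is tokens and the last dim is head_size.
query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
key = torch.randn_like(query)
value = torch.randn_like(query)

# Assumes ref_multi_query_kv_attention from this file is in scope.
ref_output = ref_multi_query_kv_attention(cu_seq_lens, query, key, value, dtype)
print(ref_output.shape)  # expected to match the packed input shape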