diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
index 2f8b64363..b769d626a 100644
--- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
+++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -287,6 +287,8 @@ mha_fwd_kvcache(at::Tensor &q,                                     // batch_siz
                 bool is_rotary_interleaved,                        // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits)
 {
+    TORCH_CHECK(false, "vllm layout does not support mha_fwd_kvcache for now");
+
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py
index b9091f0dc..e2bb1187c 100644
--- a/tests/test_flash_attn_ck.py
+++ b/tests/test_flash_attn_ck.py
@@ -1035,271 +1035,271 @@ def test_flash_attn_varlen_causal(
 # TODO - Support has_leftpad
-@pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("num_splits", [1, 0])
-@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("new_kv", [False, True])
-@pytest.mark.parametrize("alibi", [False, True])
-@pytest.mark.parametrize("local", [False, True])
-@pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
-@pytest.mark.parametrize("rotary_interleaved", [False, True])
-@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
-@pytest.mark.parametrize("has_leftpad", [False])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
-@pytest.mark.parametrize(
-    "seqlen_q,seqlen_k",
-    [
-        (1, 128),
-        (1, 339),
-        (3, 1024),
-        (64, 800),
-        (64, 256),
-        (3, 799),
-        (64, 2048),
-        (16, 20000),
-        (1, 128 * 1024),
-        (16, 128 * 1024),
-        (128, 128),
-    ],
-)
-def test_flash_attn_kvcache(
-    seqlen_q,
-    seqlen_k,
-    d,
-    has_batch_idx,
-    has_leftpad,
-    paged_kv_block_size,
-    rotary_fraction,
-    rotary_interleaved,
-    seqlen_new_eq_seqlen_q,
-    causal,
-    local,
-    alibi,
-    new_kv,
-    mha_type,
-    num_splits,
-    dtype,
-):
-    if seqlen_q > seqlen_k and new_kv:
-        pytest.skip()
-    if not new_kv and rotary_fraction > 0.0:
-        pytest.skip()
-    if has_batch_idx and paged_kv_block_size is not None:
-        pytest.skip()
-    if has_leftpad and paged_kv_block_size is not None:
-        pytest.skip()
-    device = "cuda"
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 1
-    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
-    nheads = 6
-    # rotary_dim must be a multiple of 16, and must be <= d
-    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
-    assert nheads % nheads_k == 0
-    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
-    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
-    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
-    if new_kv:
-        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
-        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
-    else:
-        k, v = None, None
-    if paged_kv_block_size is None:
-        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-        block_table = None
-    else:
-        (
-            k_cache,
-            v_cache,
-            block_table,
-            k_cache_paged,
-            v_cache_paged,
-            num_blocks,
-        ) = _generate_block_kvcache(
-            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
-        )
-    cache_seqlens = torch.randint(
-        0 if new_kv else 1,
-        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (
-            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-            if new_kv
-            else (seqlen_k + 1)
-        ),
-        (batch_size,),
-        dtype=torch.int32,
-        device=device,
-    )
-    if has_leftpad:
-        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
-                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
-                                   for i in range(batch_size)])
-    else:
-        cache_leftpad = None
-    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
-    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
-    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
-    if has_leftpad:
-        key_padding_mask = torch.logical_and(
-            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
-        )
-    if has_batch_idx:
-        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
-            :batch_size
-        ]
-    else:
-        cache_batch_idx = None
-    if alibi:
-        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
-        attn_bias = attn_bias_from_alibi_slopes(
-            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
-        )
-    else:
-        alibi_slopes, attn_bias = None, None
-    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
-    if rotary_dim > 0:
-        angle = (
-            torch.rand(
-                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
-                rotary_dim // 2,
-                device=device,
-            )
-            * 2
-            * math.pi
-        )
-        cos = torch.cos(angle).to(dtype=dtype)
-        sin = torch.sin(angle).to(dtype=dtype)
-        if causal or local:
-            q_ro = apply_rotary_emb(
-                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
-            )
-        else:
-            q_ro = rearrange(
-                apply_rotary_emb(
-                    rearrange(q, "b s h d -> b 1 (s h) d"),
-                    cos,
-                    sin,
-                    seqlen_offsets=cache_seqlens,
-                    interleaved=rotary_interleaved,
-                ),
-                "b 1 (s h) d -> b s h d",
-                s=seqlen_q,
-            )
-        # q_ro = q
-        k_ro = apply_rotary_emb(
-            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
-        )
-    else:
-        cos, sin = None, None
-        q_ro, k_ro = q, k
-    # k_cache[:, 64:] = -1
-    k_cache_ref = (
-        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
-    ).clone()
-    v_cache_ref = (
-        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
-    ).clone()
-    if new_kv:
-        update_mask = torch.logical_and(
-            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
-        )
-        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
-        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
-    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
-    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
-    out = flash_attn_with_kvcache(
-        q,
-        k_cache if paged_kv_block_size is None else k_cache_paged,
-        v_cache if paged_kv_block_size is None else v_cache_paged,
-        k,
-        v,
-        rotary_cos=cos,
-        rotary_sin=sin,
-        cache_seqlens=cache_seqlens,
-        cache_batch_idx=cache_batch_idx,
-        cache_leftpad=cache_leftpad,
-        block_table=block_table,
-        causal=causal,
-        window_size=window_size,
-        rotary_interleaved=rotary_interleaved,
-        alibi_slopes=alibi_slopes,
-        num_splits=num_splits,
-    )
-    # out = flash_attn_with_kvcache(
-    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
-    # )
-    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
-    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
-    # m = qk.amax(-1, keepdim=True)
-    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
-    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
-    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
-    # probs = torch.softmax(qk, dim=-1)
-    out_ref, _ = attention_ref(
-        q_ro,
-        k_cache_rep,
-        v_cache_rep,
-        None,
-        key_padding_mask,
-        attn_bias,
-        0.0,
-        None,
-        causal=causal,
-        window_size=window_size,
-        key_leftpad=cache_leftpad,
-    )
-    out_pt, _ = attention_ref(
-        q_ro,
-        k_cache_rep,
-        v_cache_rep,
-        None,
-        key_padding_mask,
-        attn_bias,
-        0.0,
-        None,
-        causal=causal,
-        window_size=window_size,
-        upcast=False,
-        reorder_ops=True,
-        key_leftpad=cache_leftpad,
-    )
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
-    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-
-    # Check that FlashAttention's numerical error is at most twice the numerical error
-    # of a Pytorch implementation.
-    if new_kv:
-        if paged_kv_block_size is None:
-            k_cache_select = (
-                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
-            )
-            v_cache_select = (
-                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
-            )
-        else:
-            k_cache_select = rearrange(
-                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-                b=batch_size,
-            )[:, :seqlen_k]
-            v_cache_select = rearrange(
-                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-                b=batch_size,
-            )[:, :seqlen_k]
-        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache_select, v_cache_ref)
-    # mult = 3 if f16, bf16 need 4
-    mult = 4 if not alibi else 5
-    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
+# @pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize("num_splits", [1, 0])
+# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("new_kv", [False, True])
+# @pytest.mark.parametrize("alibi", [False, True])
+# @pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
+# @pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+# @pytest.mark.parametrize("has_leftpad", [False])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
+# @pytest.mark.parametrize(
+#     "seqlen_q,seqlen_k",
+#     [
+#         (1, 128),
+#         (1, 339),
+#         (3, 1024),
+#         (64, 800),
+#         (64, 256),
+#         (3, 799),
+#         (64, 2048),
+#         (16, 20000),
+#         (1, 128 * 1024),
+#         (16, 128 * 1024),
+#         (128, 128),
+#     ],
+# )
+# def test_flash_attn_kvcache(
+#     seqlen_q,
+#     seqlen_k,
+#     d,
+#     has_batch_idx,
+#     has_leftpad,
+#     paged_kv_block_size,
+#     rotary_fraction,
+#     rotary_interleaved,
+#     seqlen_new_eq_seqlen_q,
+#     causal,
+#     local,
+#     alibi,
+#     new_kv,
+#     mha_type,
+#     num_splits,
+#     dtype,
+# ):
+#     if seqlen_q > seqlen_k and new_kv:
+#         pytest.skip()
+#     if not new_kv and rotary_fraction > 0.0:
+#         pytest.skip()
+#     if has_batch_idx and paged_kv_block_size is not None:
+#         pytest.skip()
+#     if has_leftpad and paged_kv_block_size is not None:
+#         pytest.skip()
+#     device = "cuda"
+#     # set seed
+#     torch.random.manual_seed(0)
+#     batch_size = 1
+#     batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
+#     nheads = 6
+#     # rotary_dim must be a multiple of 16, and must be <= d
+#     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
+#     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+#     assert nheads % nheads_k == 0
+#     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+#     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+#     seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
+#     if new_kv:
+#         k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+#         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+#     else:
+#         k, v = None, None
+#     if paged_kv_block_size is None:
+#         k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+#         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+#         block_table = None
+#     else:
+#         (
+#             k_cache,
+#             v_cache,
+#             block_table,
+#             k_cache_paged,
+#             v_cache_paged,
+#             num_blocks,
+#         ) = _generate_block_kvcache(
+#             seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
+#         )
+#     cache_seqlens = torch.randint(
+#         0 if new_kv else 1,
+#         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+#         (
+#             (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+#             if new_kv
+#             else (seqlen_k + 1)
+#         ),
+#         (batch_size,),
+#         dtype=torch.int32,
+#         device=device,
+#     )
+#     if has_leftpad:
+#         cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+#                                    if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+#                                    for i in range(batch_size)])
+#     else:
+#         cache_leftpad = None
+#     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+#     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+#     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+#     if has_leftpad:
+#         key_padding_mask = torch.logical_and(
+#             key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+#         )
+#     if has_batch_idx:
+#         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
+#             :batch_size
+#         ]
+#     else:
+#         cache_batch_idx = None
+#     if alibi:
+#         alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+#         attn_bias = attn_bias_from_alibi_slopes(
+#             alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
+#         )
+#     else:
+#         alibi_slopes, attn_bias = None, None
+#     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+#     if rotary_dim > 0:
+#         angle = (
+#             torch.rand(
+#                 seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
+#                 rotary_dim // 2,
+#                 device=device,
+#             )
+#             * 2
+#             * math.pi
+#         )
+#         cos = torch.cos(angle).to(dtype=dtype)
+#         sin = torch.sin(angle).to(dtype=dtype)
+#         if causal or local:
+#             q_ro = apply_rotary_emb(
+#                 q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+#             )
+#         else:
+#             q_ro = rearrange(
+#                 apply_rotary_emb(
+#                     rearrange(q, "b s h d -> b 1 (s h) d"),
+#                     cos,
+#                     sin,
+#                     seqlen_offsets=cache_seqlens,
+#                     interleaved=rotary_interleaved,
+#                 ),
+#                 "b 1 (s h) d -> b s h d",
+#                 s=seqlen_q,
+#             )
+#         # q_ro = q
+#         k_ro = apply_rotary_emb(
+#             k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+#         )
+#     else:
+#         cos, sin = None, None
+#         q_ro, k_ro = q, k
+#     # k_cache[:, 64:] = -1
+#     k_cache_ref = (
+#         k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+#     ).clone()
+#     v_cache_ref = (
+#         v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+#     ).clone()
+#     if new_kv:
+#         update_mask = torch.logical_and(
+#             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+#         )
+#         k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
+#         v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
+#     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+#     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+#     out = flash_attn_with_kvcache(
+#         q,
+#         k_cache if paged_kv_block_size is None else k_cache_paged,
+#         v_cache if paged_kv_block_size is None else v_cache_paged,
+#         k,
+#         v,
+#         rotary_cos=cos,
+#         rotary_sin=sin,
+#         cache_seqlens=cache_seqlens,
+#         cache_batch_idx=cache_batch_idx,
+#         cache_leftpad=cache_leftpad,
+#         block_table=block_table,
+#         causal=causal,
+#         window_size=window_size,
+#         rotary_interleaved=rotary_interleaved,
+#         alibi_slopes=alibi_slopes,
+#         num_splits=num_splits,
+#     )
+#     # out = flash_attn_with_kvcache(
+#     #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
+#     # )
+#     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
+#     # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
+#     # m = qk.amax(-1, keepdim=True)
+#     # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+#     # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
+#     # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
+#     # probs = torch.softmax(qk, dim=-1)
+#     out_ref, _ = attention_ref(
+#         q_ro,
+#         k_cache_rep,
+#         v_cache_rep,
+#         None,
+#         key_padding_mask,
+#         attn_bias,
+#         0.0,
+#         None,
+#         causal=causal,
+#         window_size=window_size,
+#         key_leftpad=cache_leftpad,
+#     )
+#     out_pt, _ = attention_ref(
+#         q_ro,
+#         k_cache_rep,
+#         v_cache_rep,
+#         None,
+#         key_padding_mask,
+#         attn_bias,
+#         0.0,
+#         None,
+#         causal=causal,
+#         window_size=window_size,
+#         upcast=False,
+#         reorder_ops=True,
+#         key_leftpad=cache_leftpad,
+#     )
+#     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+#     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+#     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+#     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+#     # Check that FlashAttention's numerical error is at most twice the numerical error
+#     # of a Pytorch implementation.
+#     if new_kv:
+#         if paged_kv_block_size is None:
+#             k_cache_select = (
+#                 k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+#             )
+#             v_cache_select = (
+#                 v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+#             )
+#         else:
+#             k_cache_select = rearrange(
+#                 k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+#                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+#                 b=batch_size,
+#             )[:, :seqlen_k]
+#             v_cache_select = rearrange(
+#                 v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+#                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+#                 b=batch_size,
+#             )[:, :seqlen_k]
+#         assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+#         assert torch.equal(v_cache_select, v_cache_ref)
+#     # mult = 3 if f16, bf16 need 4
+#     mult = 4 if not alibi else 5
+#     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5