-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for small page sizes #824
base: main
Are you sure you want to change the base?
Conversation
Thanks for your great work! Small page size is important for llm inference framework. Expect this pr could be merged soon. |
Fixed issue with fused RoPE embeddings - should be ready for review. |
*apply change from pull request : Dao-AILab/flash-attention#824
*apply change from pull request : Dao-AILab/flash-attention#824
Hi, I am waiting for this PR! Is this planning to be merged soon? Also, can I ask when it is planned to be released? |
Not sure - @tridao if you have time, would greatly appreciate a review so I can make any changes necessary to get this PR merged! |
Hi @skrider, thanks for the great work! Based on my test, this kernel is 1.5-4x faster than the triton equivalent. But when I use it for end-to-end testing in vLLM, I hit import torch
from flash_attn import flash_attn_with_kvcache
def cdiv(a, b):
return (a + b - 1) // b
block_size = 16
num_blocks = 1000*16//block_size
bs = 4
seq_len = 170
num_heads = 32
head_dim = 128
key_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
value_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
cache_seqlens = torch.zeros(bs, dtype=torch.int32).cuda()
for _ in range(1000):
query = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
key = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
block_tables = torch.randint(0, num_blocks, size=(bs, cdiv(seq_len, block_size)), dtype=torch.int32, device="cuda")
output = flash_attn_with_kvcache(
query,
key_cache,
value_cache,
k=key,
v=value,
cache_seqlens=cache_seqlens,
block_table=block_tables,
causal=True,
) Error message:
Some observations:
|
@ymwangg Thanks for the heads up - I will look into it. Reproducing with the provided code, I believe the error is with an async copy not being properly awaited. Synchronizing after every launch and setting a manual seed does not get rid of the nondeterminism. Additionally, if I only run the kernel every two iterations, then the kernel never errors. Small num_heads also gets rid of the issue. To me this suggests that the state of the L2 cache is correlated with the error. Signing off for tonight but will revisit when I have time. |
@ymwangg is it possible that passing k=key and v=value each iteration causing seq_len=170 tokens appended to the kvcache each time, which overflows after couple of iterations? |
My understanding is that this function does not allocate new memory but rather using |
after inspection locally found that the illegal mem access is caused by table_diff calculation overflows and propagated to further iterations. since the n_block is iterated in reverse order, the calculated virtual_page_idx_next of the page_table may be larger than the table allocated in the first round, getting undetermined table_diffs and never get fixed by advancing tKgK.data() relatively. so the fix is straight forward: use tested locally the illegal access is gone and the flash_attn_kvcache tests are passed. |
Btw, do we plan to merge this soon? |
The fix mentioned by @gnap was implemented by @ymwangg in this commit: ymwangg@7354198. @gnap, would you mind checking it to verify that's what you had in mind? @ymwangg told me the illegal access is gone with that commit on top of this PR. Could someone pull that fix into this PR and fix the conflicts so this can be merged? |
@davidthomas426 @ymwangg I have checked that commit and the modification is mutually identical with my local change. currently I am conducting more tests with our internal inference engine. but if the vllm community tests okay, feel free to commit or notify @skrider to update this PR. |
Thanks for your great work! Does this PR support varlen with KV block: 2a15840 |
Thank you everyone for all the help! I will review locally and push the fix. I used the difference between page indices rather than calculating the offset directly because that's how it was done originally. Besides saving a register I am not sure if there are any advantages to doing this. @gnap curious what your process was for finding the bug?
In progress, expect it sometime next week |
These changes pass unit tests for standard and varlen APIs as well as the example provided above by @ymwangg |
by ran the |
Thanks so much for your work @skrider. Can you rebase and then I'll merge? |
@skrider Are you going to rebase this so it can get merged? |
@tridao absolutely! Sorry, just seeing this. Notification fell through the cracks. |
@skrider It looks like rebasing was pretty easy. would you mind if I just create a PR to your branch? (I just ran git merge main, and no conflict) |
// assumes that the tensor has already been positioned at the correct head. | ||
template <typename Kernel_traits> | ||
__forceinline__ __device__ | ||
int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your great work!
I think int64 should be used here, because using int32 may overflow when page is multiplied by stride. The following code will raise an illegal memory access error in my test (A100 40G)
import torch
from flash_attn import flash_attn_varlen_func
torch.manual_seed(0)
num_pages = 4048
page_id = 4000
page_size = 256
num_heads = 32
head_size = 128
seq_len = 13
q = torch.randn(seq_len, 32, 128, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(num_pages, page_size, num_heads, head_size, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(num_pages, page_size, num_heads, head_size, device="cuda", dtype=torch.float16)
cu_seqlens_q = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)
seqlens_k = torch.tensor([seq_len], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)
block_table=torch.tensor([[page_id]], device='cuda', dtype=torch.int32)
k_cache[page_id, :seq_len] = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
v_cache[page_id, :seq_len] = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
flash_attn_varlen_func(
q=q,
k=k_cache,
v=v_cache,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
causal=True,
block_table=block_table,
)
@skrider Hi, any updates on this PR? |
Any updates on the current PR ? @skrider |
This PR has already been merged into this repository and is now part of the version of flash_attention, which vllm depends on. |
Recently, support has been added for paged attention with large page sizes of 256 tokens. However, projects which use paged attention prefer smaller page sizes of around 16. This PR adds support for smaller page sizes by reshaping the GMEM -> SMEM copy to ensure that in each iteration of the mainloop each thread fetches only from a single page. Hence physical page addresses need only be resolved at the beginning of each mainloop iteration and can be resolved per-thread rather than per-CTA.
Preliminary benchmarking with
ncu
on the unit testing suite shows no degradation in performance.