
Add support for small page sizes #824

Open
skrider wants to merge 19 commits into main

Conversation

@skrider

skrider commented Feb 13, 2024

Recently, support was added for paged attention with large page sizes of 256 tokens. However, projects that use paged attention generally prefer smaller page sizes of around 16 tokens. This PR adds support for smaller page sizes by reshaping the GMEM -> SMEM copy so that, in each iteration of the mainloop, each thread fetches from only a single page. Hence physical page addresses only need to be resolved at the beginning of each mainloop iteration and can be resolved per-thread rather than per-CTA.

Preliminary benchmarking with ncu on the unit testing suite shows no degradation in performance.
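Roughly, the per-thread resolution looks like this (a Python sketch with made-up names, not the actual CUDA code):

```python
def resolve_thread_kv_offset(tidx, n_block, block_table, kBlockN,
                             page_block_size, threads_per_row,
                             page_stride, row_stride):
    # First K/V row this thread copies in the current kBlockN tile; the copy is
    # shaped so that every row a given thread touches lies inside one page.
    row = n_block * kBlockN + tidx // threads_per_row
    # Virtual page -> physical page via the block table, then the element
    # offset of that row inside the paged K/V cache.
    physical_page = block_table[row // page_block_size]
    return physical_page * page_stride + (row % page_block_size) * row_stride
```

Since a thread never crosses a page boundary within a tile, this lookup only has to happen once per mainloop iteration.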

skrider marked this pull request as draft February 13, 2024 08:16
@zhaoyang-star

Thanks for your great work! Small page sizes are important for LLM inference frameworks. I hope this PR can be merged soon.

@skrider
Author

skrider commented Feb 26, 2024

Fixed issue with fused RoPE embeddings - should be ready for review.

guocuimi added a commit to vectorch-ai/ScaleLLM that referenced this pull request Feb 27, 2024
@rkooo567

rkooo567 commented Feb 28, 2024

Hi, I am waiting for this PR! Is it planned to be merged soon? Also, may I ask when it is planned to be released?

@skrider
Author

skrider commented Feb 29, 2024

Not sure - @tridao if you have time, would greatly appreciate a review so I can make any changes necessary to get this PR merged!

@ymwangg

ymwangg commented Mar 2, 2024

Hi @skrider, thanks for the great work! Based on my tests, this kernel is 1.5-4x faster than the Triton equivalent. But when I use it for end-to-end testing in vLLM, I hit RuntimeError: CUDA error: an illegal memory access was encountered.
Below is the minimal code to reproduce:

import torch
from flash_attn import flash_attn_with_kvcache

def cdiv(a, b):
    return (a + b - 1) // b

block_size = 16
num_blocks = 1000*16//block_size
bs = 4
seq_len = 170
num_heads = 32
head_dim = 128

key_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
value_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
cache_seqlens = torch.zeros(bs, dtype=torch.int32).cuda()

for _ in range(1000):
    query = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    key = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    block_tables = torch.randint(0, num_blocks, size=(bs, cdiv(seq_len, block_size)), dtype=torch.int32, device="cuda")
    output = flash_attn_with_kvcache(
        query,
        key_cache,
        value_cache,
        k=key,
        v=value,
        cache_seqlens=cache_seqlens,
        block_table=block_tables,
        causal=True,
    )

Error message:

Traceback (most recent call last):
  File "/home/ubuntu/src/vllm-test/debug.py", line 21, in <module>
    value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Some observations:

  • This only occurs in the prefill stage, and it happens sporadically. Using it for decoding (single-query or multi-query) seems fine.
  • The error goes away after increasing block_size to 256.
  • The error still occurs after removing k=key and v=value, so the illegal memory access may happen when reading from the paged blocks.

@skrider
Author

skrider commented Mar 2, 2024

@ymwangg Thanks for the heads up - I will look into it.

Reproducing with the provided code, I believe the error is due to an async copy not being properly awaited. Synchronizing after every launch and setting a manual seed does not get rid of the nondeterminism. Additionally, if I only run the kernel every two iterations, it never errors. A small num_heads also gets rid of the issue. To me this suggests that the state of the L2 cache is correlated with the error. Signing off for tonight, but I will revisit when I have time.

@gnap

gnap commented Mar 11, 2024

@ymwangg is it possible that passing k=key and v=value each iteration causes seq_len=170 tokens to be appended to the KV cache each time, which overflows after a couple of iterations?

@ymwangg

ymwangg commented Mar 11, 2024

@ymwangg is it possible that passing k=key and v=value each iteration causes seq_len=170 tokens to be appended to the KV cache each time, which overflows after a couple of iterations?

My understanding is that this function does not allocate new memory but rather uses block_table to identify the memory addresses to read/write. So as long as the block ids in block_table are valid, it should not cause an overflow issue.
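For a quick sanity check, just arithmetic on the repro above:

```python
# Nothing accumulates across iterations: cache_seqlens stays at 0, so each call
# writes the new k/v starting at slot 0 of the blocks listed in block_tables.
block_size, seq_len = 16, 170
blocks_per_seq = (seq_len + block_size - 1) // block_size  # cdiv(170, 16) = 11
capacity = blocks_per_seq * block_size                     # 11 * 16 = 176 slots per sequence
assert capacity >= seq_len  # 176 >= 170, so valid block ids alone cannot overflow the cache
```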

@gnap

gnap commented Mar 14, 2024

After inspecting locally, I found that the illegal memory access is caused by the table_diff calculation going out of bounds and propagating to later iterations.

Since n_block is iterated in reverse order, the calculated virtual_page_idx_next into the page table may be larger than the table allocated in the first round, yielding undefined table_diff values that never get fixed by advancing tKgK.data() relatively.

So the fix is straightforward: use init_thread_kv_page_slice_offset(..., n_block, ...) to calculate the absolute offset and add it to gK.data()/gV.data() directly. copy_w_min_idx() guarantees that only rows in range are copied.

Tested locally: the illegal access is gone and the flash_attn_kvcache tests pass.
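To make the difference concrete, here is a rough Python stand-in for the two indexing schemes (illustrative names only, not the actual CUDA code):

```python
def absolute_offset(block_table, n_block, kBlockN, page_block_size,
                    page_stride, row_stride):
    # Fixed scheme: recompute the full offset for the current n_block each iteration.
    row = n_block * kBlockN
    return (block_table[row // page_block_size] * page_stride
            + (row % page_block_size) * row_stride)

def relative_table_diff(block_table, n_block, kBlockN, page_block_size):
    # Buggy scheme: advance the pointer by the difference between physical pages.
    # With n_block iterated in reverse, virtual_page_idx_next can point past the
    # entries actually present in block_table on the first round; in CUDA that
    # reads garbage (in this Python model it would raise IndexError), and every
    # later iteration inherits the corrupted pointer.
    virtual_page_idx = (n_block * kBlockN) // page_block_size
    virtual_page_idx_next = ((n_block + 1) * kBlockN) // page_block_size
    return block_table[virtual_page_idx_next] - block_table[virtual_page_idx]
```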

@rkooo567

Btw, do we plan to merge this soon?

@davidthomas426

The fix mentioned by @gnap was implemented by @ymwangg in this commit: ymwangg@7354198.

@gnap, would you mind checking it to verify that's what you had in mind? @ymwangg told me the illegal access is gone with that commit on top of this PR.

Could someone pull that fix into this PR and fix the conflicts so this can be merged?

@gnap

gnap commented Mar 21, 2024

@davidthomas426 @ymwangg I have checked that commit and the modification is identical to my local change. Currently I am conducting more tests with our internal inference engine, but if the vLLM community's tests pass, feel free to commit it or notify @skrider to update this PR.

@mjp9527

mjp9527 commented Mar 21, 2024

Thanks for your great work! Does this PR support varlen with KV blocks (2a15840)?

@skrider
Author

skrider commented Mar 22, 2024

Thank you everyone for all the help! I will review locally and push the fix. I used the difference between page indices rather than calculating the offset directly because that's how it was done originally; besides saving a register, I am not sure there are any advantages to doing it that way.

@gnap curious what your process was for finding the bug?

Thanks for your great work! Does this PR support varlen with KV blocks (2a15840)?

In progress, expect it sometime next week

@skrider
Author

skrider commented Mar 26, 2024

These changes pass the unit tests for the standard and varlen APIs, as well as the example provided above by @ymwangg.

@gnap

gnap commented Mar 28, 2024

@gnap curious what your process was for finding the bug?

By running compute-sanitizer --tool memcheck against the reproduction code @ymwangg provided, which showed that some threads accessed memory addresses far below gK's and gV's gmem_ptr. Then, with some print statements, I found that table_diffs could be larger than the partitioned copy tile's strides.

@tridao
Contributor

tridao commented Apr 10, 2024

Thanks so much for your work @skrider. Can you rebase and then I'll merge?

@davidthomas426

@skrider Are you going to rebase this so it can get merged?

@skrider
Author

skrider commented May 2, 2024

@tridao absolutely! Sorry, just seeing this. Notification fell through the cracks.

@rkooo567

rkooo567 commented May 3, 2024

@skrider It looks like rebasing was pretty easy. Would you mind if I just create a PR to your branch? (I just ran git merge main, and there were no conflicts.)

// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,

Thanks for your great work!
I think int64 should be used here, because int32 may overflow when the page index is multiplied by the stride. The following code raises an illegal memory access error in my test (A100 40GB):

import torch
from flash_attn import flash_attn_varlen_func

torch.manual_seed(0)

num_pages = 4048
page_id = 4000
page_size = 256
num_heads = 32
head_size = 128
seq_len = 13

q = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(num_pages, page_size, num_heads, head_size, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(num_pages, page_size, num_heads, head_size, device="cuda", dtype=torch.float16)
cu_seqlens_q = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)
seqlens_k = torch.tensor([seq_len], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)
block_table=torch.tensor([[page_id]], device='cuda', dtype=torch.int32)

k_cache[page_id, :seq_len] = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
v_cache[page_id, :seq_len] = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)

flash_attn_varlen_func(
    q=q,
    k=k_cache,
    v=v_cache,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=seq_len,
    max_seqlen_k=seq_len,
    causal=True,
    block_table=block_table,
)
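For reference, the arithmetic behind the overflow (element offset of the page being indexed):

```python
page_id, page_size, num_heads, head_size = 4000, 256, 32, 128
offset = page_id * page_size * num_heads * head_size  # element offset into k_cache
print(offset)              # 4194304000
print(offset > 2**31 - 1)  # True: exceeds INT32_MAX (2147483647), so int32
                           # pointer arithmetic wraps and produces a bogus address
```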

@yangelaboy

@skrider Hi, any updates on this PR?

@jorgeantonio21
Contributor

Any updates on the current PR ? @skrider

@itsliupeng

This PR has already been merged into this repository and is now part of the flash_attention version that vLLM depends on.
