Make flashinfer kernels cuda graphs friendly #187
Hi @AgrawalAmey, thanks for bringing this up. I have some ideas about integrating CUDA graphs with flashinfer: the kernels to be executed can be determined before a decode/prefill step (for all layers) by analyzing the shapes, so we can compile CUDA graphs for all possible combinations (not too many) ahead of time and dispatch to one of them according to the shapes. Regarding dynamic parallelism:
It sounds tricky to me because the required shared memory size/grid size varies for different schedules. |
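The ahead-of-time capture idea above can be sketched in plain Python. This is only an illustration: `GraphCache`, the `(phase, batch bucket)` key, and the bucket sizes are hypothetical, and each "graph" is a closure standing in for a captured CUDA graph, not a flashinfer or CUDA API.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical dispatch key: (phase, padded batch size). A real key would
# follow the library's actual scheduling rules, which are not shown here.
ShapeKey = Tuple[str, int]


class GraphCache:
    """Stand-in for a table of pre-captured CUDA graphs, one per shape bucket."""

    def __init__(self, batch_buckets: List[int]) -> None:
        self.batch_buckets = sorted(batch_buckets)
        self.graphs: Dict[ShapeKey, Callable[[], str]] = {}

    def capture_all(self, phases=("decode", "prefill")) -> None:
        # Ahead of time: record one entry per (phase, bucket) combination.
        for phase in phases:
            for bucket in self.batch_buckets:
                key = (phase, bucket)
                # A real implementation would capture a CUDA graph here;
                # this sketch stores a closure that reports which one ran.
                self.graphs[key] = (
                    lambda key=key: f"replay {key[0]} graph for batch<={key[1]}"
                )

    def bucket_for(self, batch_size: int) -> int:
        # Pick the smallest bucket that can hold the batch.
        for bucket in self.batch_buckets:
            if batch_size <= bucket:
                return bucket
        raise ValueError(f"batch size {batch_size} exceeds the largest bucket")

    def dispatch(self, phase: str, batch_size: int) -> str:
        # At run time: look up the pre-captured graph and replay it.
        return self.graphs[(phase, self.bucket_for(batch_size))]()


cache = GraphCache([1, 2, 4, 8])
cache.capture_all()
print(cache.dispatch("decode", 3))
```

With two phases and four buckets this table holds eight entries; the cost of this approach is governed by how many distinct keys have to be captured.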
Hi @yzh119! I have one implementation in sarathi-serve which tries to enumerate the different combinations and capture them. But with increasing batch size and large variance across input sequences, the number of possibilities seemed to explode. Plus, batching prefill and decode requests together makes it even more challenging. The memory cost of CUDA graphs becomes too high as the number of combinations increases. The child kernel/dynamic parallelism proposal aims to solve the challenge of varying grid sizes etc. Essentially, the launcher kernel will be launched with a single warp. Inside the launcher kernel, we can determine all the launch params and launch the actual attention kernel. |
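A rough count shows how quickly the number of graphs can grow. The sketch below assumes, purely for illustration, that each request falls into one of a few sequence-length buckets and that every distinct multiset of buckets needs its own captured graph; `num_graph_variants` and the bucket count are hypothetical and are not taken from sarathi-serve.

```python
from math import comb


def num_graph_variants(max_batch_size: int, num_length_buckets: int) -> int:
    """Count distinct batches when each request falls into one length bucket.

    A batch of size b is a multiset of b buckets, so there are
    C(num_length_buckets + b - 1, b) such batches; sum over b = 1..max_batch_size.
    """
    return sum(
        comb(num_length_buckets + b - 1, b)
        for b in range(1, max_batch_size + 1)
    )


# Print the count for a few maximum batch sizes with 8 length buckets.
for max_bs in (4, 16, 64):
    print(max_bs, num_graph_variants(max_bs, num_length_buckets=8))
```

Mixing prefill and decode requests in one batch would add another dimension to the key and enlarge the count further.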
A sample program to explain what I mean:
|
Thanks for your explanation, that sounds reasonable. To proceed, I'd love to write some documentation on our dispatching rules and see if we can express them with dynamic parallelism. Before that I have to finish #75 because it will affect our dispatching strategy. I'll be glad to follow up next week and we can schedule a meeting on Zoom (you can drop me an email at zhye@cs.washington.edu). |
Yes, that would be great. I will send out a when2meet link by email, thank you! |
Hi, @AgrawalAmey, will your sarathi or sarathi-serve be open-sourced? |
Hey @ZSL98, we are working with the vLLM team to get Sarathi-Serve scheduler support inside vLLM |
As requested in #187, this PR adds initial `CUDAGraph` compatibility for flashinfer batch decode attention kernels. This PR is the first step towards full CUDAGraph support; we will implement CUDAGraph-compatible prefill operators in later PRs.
# Proposed APIs
We add another wrapper, `CUDAGraphBatchDecodeWithPagedKVCacheWrapper`. Users need to pre-allocate the page data structure buffers to initialize this wrapper class. Once initialized, these buffers are pinned on the GPU for the lifetime of the wrapper. The behavior of `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` differs slightly from `BatchDecodeWithPagedKVCacheWrapper`'s: we only run a fixed set of kernels in CUDAGraph mode, no matter what the input shape is (the original implementation dispatches to different kernels according to the input shapes). This PR also fixes the addresses of all kernel input pointers to accommodate the constraints of CUDAGraph capture.
# Examples
See `test_cuda_graph_batch_decode_with_paged_kv_cache` in the unit tests. `begin_forward` functions should not be captured, as some of their operators are not allowed to be captured.
cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac
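The fixed-address requirement can be illustrated without a GPU. In the sketch below, `FixedIndexBuffer` is a hypothetical stand-in for one pre-allocated page-table buffer (it is not the wrapper's real API): new metadata is copied into the existing storage before each replay, so the address that a captured graph would have recorded stays valid.

```python
from array import array
from typing import Sequence


class FixedIndexBuffer:
    """Hypothetical stand-in for a pre-allocated buffer with a stable address."""

    def __init__(self, capacity: int) -> None:
        # Allocate once, at wrapper construction time.
        self.data = array("i", [0] * capacity)
        self.length = 0  # number of valid entries

    def address(self) -> int:
        # Address of the underlying storage (what a captured graph would hold).
        return self.data.buffer_info()[0]

    def update(self, values: Sequence[int]) -> None:
        if len(values) > len(self.data):
            raise ValueError("metadata does not fit in the pre-allocated buffer")
        # Equal-length slice assignment overwrites in place, without reallocating.
        self.data[: len(values)] = array("i", values)
        self.length = len(values)


indptr = FixedIndexBuffer(capacity=16)
addr_at_capture = indptr.address()

indptr.update([0, 3, 7])       # metadata for one step
indptr.update([0, 5, 9, 12])   # metadata for a later step with a larger batch

print(indptr.address() == addr_at_capture, list(indptr.data[: indptr.length]))
```

The capacity has to cover the largest batch the wrapper will ever see, which is why the buffers are supplied up front.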
CUDA graph compatibility was resolved in release v0.0.5 (https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5). The current strategy is:
|
@yzh119 thanks a lot for all the amazing work! I wanted to understand how split-k behaves when the sequence length differs significantly between capture and replay time. For instance, if we have a sequence length of 1k during capture and a sequence of length 100k during replay, would the parallelization parameters be applied appropriately? |
Yes, they will be properly handled. When CUDA graph mode is enabled, we decide whether to use split-k based only on batch size (for decode) and query lengths (for append), not on kv-cache length, so it's safe to capture when the kv-cache length is small. We have test cases for capturing with a small kv-length and replaying with a long one: flashinfer/python/tests/test_batch_decode_kernels.py Lines 136 to 286 in 231b1dc
|
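One way to picture why the kv-cache length can change between capture and replay: if the launch shape depends only on batch size, then the kv length only affects values that are recomputed before each replay. The heuristic below is an assumption made for illustration; `decide_num_kv_chunks`, `plan_chunk_size`, and `max_grid_size` are hypothetical and do not reproduce flashinfer's actual rule.

```python
from math import ceil


def decide_num_kv_chunks(batch_size: int, max_grid_size: int) -> int:
    """Capture-time decision: depends on batch size only (hypothetical rule)."""
    return max(1, max_grid_size // batch_size)


def plan_chunk_size(kv_len: int, num_kv_chunks: int) -> int:
    """Per-step planning: only the per-chunk length changes with kv_len."""
    return ceil(kv_len / num_kv_chunks)


# The number of chunks, and hence the grid shape, is fixed for a given batch size.
num_chunks = decide_num_kv_chunks(batch_size=4, max_grid_size=128)

# The same launch shape serves a short and a long kv-cache; only the chunk
# size, which is plain data passed to the kernel, differs.
for kv_len in (1_000, 100_000):
    print(kv_len, num_chunks, plan_chunk_size(kv_len, num_chunks))
```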
There is one tricky part about prefill kernels, we pass
and we will change its value in |
This is great! Thanks a lot for the in-depth description. I will go ahead and add cuda graph support in sarathi-serve based on this. |
Thanks for creating these awesome kernels! I am trying to get flashinfer kernels to work with CUDA graphs. But it appears that several parallelism decisions (block size, num_q_tiles, etc.) are made on the fly based on the input data in the forward function. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution to the problem would be to introduce a launcher kernel which would factor in the input metadata and launch the actual CUDA kernel using dynamic parallelism. Towards that, the following are the items I have identified:
@yzh119 please let me know what would be the best way to proceed?