[Refactor] Uniform PoDAttention API with Horizontal Fusion SMs Schedule #967
Conversation
Some of the unit tests failed, for example `test_block_sparse_attention[False-256-16-16-128-64-16-4]`.

Hi, can I ask when this is planned to be merged? I made a PR to support POD Attn in SGLang using the old API and plan to get that working with CUDA graph first.

I really like the uniform batch API that this PR presents. I ran this on an A100 and compared it with the existing FlashInfer POD-Attention implementation. On average this performed around 10-15% worse, but still better than serial execution. Performance was worse for larger prefill context lengths, while for smaller context lengths the performance was more comparable.

Yeah, this is more convenient. One issue I had during my PR is that I have to fill a 2D attention mask for prefill every time, instead of using the page table and indices.

Will the old API be preserved? Thanks.
@AKKamath Btw, I wonder what was the reason for using a mask instead of a page table for prefill qkv?

@yzh119 can correct me here, but I believe the mask prefill kernel (single_prefill) had better performance than the page-table prefill, because the page-table prefill had higher register usage, causing register spills.

But don't we waste a lot of space storing the 2D mask? For example, the default shape is the 2D cumulative sequence lengths (qo_lens, kv_lens), but when converting from the page table (qo_indptr, kv_indptr) to the mask, it will be very sparse, with each qo related to only a few kv entries of its own request within the whole cumulative sequence. It can also be expensive to fill the mask.

Actually, I realized POD Attention is not designed to mix many prefill requests with decode requests; it just mixes one prefill at a time, so that we can use causal masking without any custom mask.
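To put rough numbers behind the storage concern above, here is a small back-of-the-envelope sketch (the request lengths are made up for illustration and are not taken from the PR): a dense mask over the cumulative (sum(qo), sum(kv)) sequences grows with the product of the totals, while the entries each request actually attends to grow only with the per-request products.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical example: compare dense 2D mask storage over the cumulative
// (sum(qo), sum(kv)) sequences with the entries actually attended when each
// request only attends to its own KV segment (block-diagonal structure).
int main() {
  std::vector<int64_t> qo_lens = {1, 1, 1, 512};  // e.g. three decodes + one prefill (made up)
  std::vector<int64_t> kv_lens = {4096, 4096, 4096, 512};

  int64_t total_qo = 0, total_kv = 0, attended = 0;
  for (size_t i = 0; i < qo_lens.size(); ++i) {
    total_qo += qo_lens[i];
    total_kv += kv_lens[i];
    attended += qo_lens[i] * kv_lens[i];  // upper bound; causal masking removes even more
  }
  int64_t dense_mask_entries = total_qo * total_kv;
  std::printf("dense mask entries:    %lld\n", (long long)dense_mask_entries);
  std::printf("attended entries (<=): %lld\n", (long long)attended);
  std::printf("density:               %.4f\n", (double)attended / (double)dense_mask_entries);
  return 0;
}
```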
Follow-up in #1026.
```cpp
std::accumulate(qo_len_ptr_h_p.begin(), qo_len_ptr_h_p.end(), 0) +
    2 * page_size * std::accumulate(kv_len_ptr_h_p.begin(), kv_len_ptr_h_p.end(), 0);
```
Hi, I'm interested in implementing a persistent POD Attn and have some questions. Why don't we use `qo_len_ptr_h_p[i] * kv_len_ptr_h_p[i] * 2` here to model the quadratic compute load? Thanks.
I am using the current calculation mainly to model memory load rather than compute load. Different workloads may have different best heuristics for this calculation; it would be helpful if you benchmark and decide.
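For reference, a hedged sketch of the two heuristics under discussion. The helper names are hypothetical: only the first mirrors the accumulate-based calculation quoted above, while the second is the compute-load alternative raised in the question, not code from this PR.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Memory-load heuristic, mirroring the snippet above: cost grows linearly with
// the total Q/O rows plus the KV pages loaded (x2 for K and V).
int64_t MemoryLoadCost(const std::vector<int64_t>& qo_lens,
                       const std::vector<int64_t>& kv_lens, int64_t page_size) {
  return std::accumulate(qo_lens.begin(), qo_lens.end(), int64_t{0}) +
         2 * page_size * std::accumulate(kv_lens.begin(), kv_lens.end(), int64_t{0});
}

// Compute-load heuristic suggested in the question: cost grows with the number
// of QK^T and PV MACs, i.e. roughly 2 * qo_len * kv_len per request.
int64_t ComputeLoadCost(const std::vector<int64_t>& qo_lens,
                        const std::vector<int64_t>& kv_lens) {
  int64_t cost = 0;
  for (size_t i = 0; i < qo_lens.size(); ++i) cost += 2 * qo_lens[i] * kv_lens[i];
  return cost;
}
```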
BTW, #1026 will be the upstream version and this PR has been deprecated. It would be helpful if you could refer directly to the new PR.
Thanks. Do you have any plans to adapt POD Attn to the persistent template? I also plan to work on that.
<!-- .github/pull_request_template.md -->

## 📌 Description

A follow-up to #858, #967, and #1026, this PR aims to provide an efficient and unified API for processing prefill and decode requests within a single kernel launch. Key features include:

1. Single CUDA graph capture for all batch sizes and sequence lengths. Prior to this PR, the FA2 template was implemented as a non-persistent kernel, which dispatches `padded_batch_sizes` CTAs and relies on static information (ref: https://github.com/flashinfer-ai/flashinfer/blob/f484fd3c7f09a1d0afb75d779872b9762a35e445/include/flashinfer/attention/scheduler.cuh#L527). This necessitates a specialized CUDA graph for each batch with different seqlens and batch sizes to maximize throughput. Furthermore, prefill and decode are executed by different kernel launches, increasing the number of CUDA graphs combinatorially. This PR implements a persistent-style kernel, which enables a single CUDA graph to capture work for all seqlens and batch sizes.
2. Dynamic specialization for prefill and decode. Implemented as a persistent kernel, prefill and decode requests are dynamically executed by an efficient kernel template with suitable hyperparameters. For example, decode requests with `qo_len=1` are processed with `CTA_TILE_Q=16`, while prefill requests with `qo_len>=128` are processed with `CTA_TILE_Q=128`.

## Perf Benchmarks

The benchmark script is at `benchmarks/bench_batch_attention.py` and was tested with Qwen-2.5-7B configurations and a single H200. Visualization:

<img width="594" alt="image" src="https://github.com/user-attachments/assets/735aca14-387d-4013-b3f4-e199b6cff5f3" />

1. 30% bandwidth boost in hybrid scenarios
2. Slightly worse perf on pure workloads, which may be caused by the reduction overhead

## Unit Tests

Unit tests can be located at `tests/bench_batch_attention.py`.

<img width="1527" alt="image" src="https://github.com/user-attachments/assets/fff06c6d-c121-497c-9f62-039653149a4d" />

## Future Works

1. Add a profiler to analyze perf bottlenecks
2. Optimize the reduction kernel schedule

<!-- What does this PR do? Briefly describe the changes and why they’re needed. -->

## 🔍 Related Issues

#1022

Advised by @yzh119. CC @AKKamath @Edenzzzz

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
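To illustrate the dynamic specialization described above, here is a minimal, hypothetical sketch of bucketing requests into tile sizes by `qo_len`. The 16 and 128 endpoints follow the description above; the intermediate bucket and the helper name are assumptions for illustration, not the actual FlashInfer dispatch logic.

```cpp
#include <cstdint>

// Hypothetical helper: pick a CTA_TILE_Q bucket from a request's qo_len.
// The PR description mentions CTA_TILE_Q=16 for qo_len=1 (decode) and
// CTA_TILE_Q=128 for qo_len>=128 (prefill); the middle bucket is an assumption.
inline int SelectCtaTileQ(int64_t qo_len) {
  if (qo_len <= 16) return 16;   // decode or very short prefill
  if (qo_len <= 64) return 64;   // medium chunked prefill (assumed bucket)
  return 128;                    // long prefill
}
```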
## Description

This PR is a follow-up to #858, which integrates the PoDAttention (arXiv link) API in a user-transparent manner. Users can now invoke PoDAttention via the same API as `BatchPrefillWithPagedKVCache`, without explicitly specifying whether requests are prefill or decode (example code).

## Key Changes
### Support for Non-Continuous Q/O and KV Tensor Layout

Previously, tensor offsets were computed using `indptr`, assuming continuous layouts. PoDAttention requires supporting mixed prefill/decode subsets within requests, necessitating a non-continuous layout. This PR introduces `q_lenptr` and `kv_lenptr` to accommodate this functionality (code link).
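As a rough illustration (the names and the base-offset array below are hypothetical, not the PR's actual data structures): a continuous layout implies per-request offsets via `indptr` prefix sums, whereas a non-continuous layout carries explicit per-request lengths so prefill and decode subsets need not be packed back to back.

```cpp
#include <cstdint>
#include <vector>

// Continuous layout: request i's rows live at [indptr[i], indptr[i + 1]),
// so offsets are implied by prefix sums over all requests.
int64_t RowOffsetFromIndptr(const std::vector<int64_t>& indptr, int i) {
  return indptr[i];
}

// Non-continuous layout (sketch): each request carries its own length, and a
// separate (assumed) base-offset array decouples where its rows start, so
// prefill and decode subsets can interleave without repacking.
struct NonContinuousLayout {
  std::vector<int64_t> q_lenptr;   // per-request qo length
  std::vector<int64_t> kv_lenptr;  // per-request kv length
  std::vector<int64_t> q_offset;   // assumed per-request base row offset
};

int64_t RowOffset(const NonContinuousLayout& layout, int i) {
  return layout.q_offset[i];  // no prefix-sum assumption across requests
}
```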
### Horizontal Fusion-Style Implementation

For improved efficiency, subsets of requests are aware of each other, enabling optimal selection of kernel hyperparameters and persistent kernel execution.

## Limitations and Future Work

The prefill/decode classifier (currently based on `qo_len > threshold`) is preliminary and requires improvement (classifier implementation).

cc @AKKamath @yzh119