[feat] add unified batch attention w/ correctness tests. #1137
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @happierpig, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request implements a significant feature by introducing a unified, persistent CUDA kernel for batch attention that efficiently handles both prefill and decode operations. This change aims to improve performance, particularly in hybrid workloads, and enhance compatibility with CUDA graph capture by allowing a single graph to cover diverse batch configurations. The implementation includes dynamic kernel specialization and is accompanied by new benchmarks and correctness tests.
Highlights
- Unified Attention Kernel: Introduced a new `BatchAttention` API and underlying CUDA kernel designed to handle both prefill and decode requests within a single kernel launch, replacing the previous approach that used separate kernels.
- Persistent Kernel & CUDA Graph Compatibility: The new kernel is implemented in a persistent style, enabling a single CUDA graph to capture attention operations for varying batch sizes and sequence lengths. This contrasts with the prior method, which required specialized CUDA graphs for different batch configurations.
- Dynamic Specialization: The persistent kernel dynamically selects efficient hyperparameters (such as CTA tile sizes) based on the characteristics of the requests, specializing for short decode sequences (`qo_len=1`) and longer prefill sequences (`qo_len>=128`).
- Performance Improvements: Benchmarks show a 30% bandwidth boost in hybrid prefill/decode scenarios compared to the old implementation. Pure workloads may see slightly worse performance due to reduction overhead, which is noted as a future optimization area.
- Comprehensive Correctness Tests: Added a new test suite (`tests/test_batch_attention.py`) that verifies the correctness of the new `BatchAttention` implementation against the old `BatchPrefillWithPagedKVCacheWrapper` across a wide range of sequence length configurations, page sizes, head counts, head dimensions, causality settings, layouts, and data types.
- New Python API: Exposed the new functionality through a `flashinfer.BatchAttention` Python class, providing `plan` and `run` methods similar to existing wrappers but for the unified kernel (a rough usage sketch follows below).
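For orientation, here is a hypothetical usage sketch of the new class. The `plan`/`run` argument names below are assumptions modeled on FlashInfer's existing paged-KV wrappers, not the authoritative `flashinfer.BatchAttention` signature:

```python
# Hypothetical usage sketch -- argument names are assumptions, not the exact API.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 28, 4, 128, 1

# One decode request (qo_len=1, kv_len=4096) and one prefill request
# (qo_len=kv_len=128), both handled by the same persistent kernel launch.
qo_indptr = torch.tensor([0, 1, 129], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 4096, 4224], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(4224, dtype=torch.int32, device="cuda")
kv_len_arr = torch.tensor([4096, 128], dtype=torch.int32, device="cuda")

q = torch.randn(129, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(4224, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device="cuda")

attn = flashinfer.BatchAttention(kv_layout="NHD")
attn.plan(                        # host-side scheduling, once per batch composition
    qo_indptr, kv_indptr, kv_indices, kv_len_arr,
    num_qo_heads, num_kv_heads, head_dim, head_dim, page_size,
    causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
out, lse = attn.run(q, kv_cache)  # one unified launch covers prefill + decode
```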
Code Review
This pull request introduces a unified batch attention mechanism using a persistent CUDA kernel, aiming to improve efficiency, particularly for CUDA graph capture, by combining prefill and decode operations. The changes include new CUDA kernels, a scheduler, Python bindings, benchmarks, and correctness tests.
Key feedback points include:
- A JIT module caching issue in `flashinfer/attention.py` that could lead to unnecessary recompilations.
- A duplicated function definition in `flashinfer/jit/attention/pytorch.py`.
- Some hardcoded values (e.g., type sizes, buffer limits) in the C++ scheduler and kernel code that might affect flexibility or robustness under different configurations.
- Minor typos in comments and code.

The overall approach of using a persistent kernel with dynamic specialization for prefill/decode and load balancing seems sound. The provided benchmarks and tests are valuable for verifying correctness and performance.
    global _batch_attention_modules
    modules_dict = _batch_attention_modules
    if args not in modules_dict:
        module = gen_batch_attention_module(*args).build_and_load()
The JIT-compiled module is not being cached. After building a new module, it should be stored in `modules_dict` to avoid recompilation on subsequent calls with the same arguments. This can significantly impact performance if `get_holistic_attention_module` is called multiple times with identical configurations.
Suggested change:

    module = gen_batch_attention_module(*args).build_and_load()
    modules_dict[args] = module  # Cache the newly built module
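For context, the full caching pattern the suggestion implies could look roughly like the following. This is an illustrative sketch, not the actual `flashinfer/attention.py` code:

```python
# Illustrative caching pattern only; names mirror the review context.
_batch_attention_modules = {}

def _build_module(*args):
    # Stand-in for gen_batch_attention_module(*args).build_and_load(),
    # i.e. the expensive JIT compile + load step.
    print(f"compiling for {args}")
    return object()

def get_holistic_attention_module(*args):
    # Key the cache on the full argument tuple so repeated calls with the same
    # dtypes / head dims / encoding mode reuse the first compilation.
    if args not in _batch_attention_modules:
        _batch_attention_modules[args] = _build_module(*args)
    return _batch_attention_modules[args]
```

If the argument tuple is always hashable, decorating the builder with `functools.lru_cache` would achieve the same effect.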
    def gen_batch_attention_module(
        dtype_q: torch.dtype,
        dtype_kv: torch.dtype,
        dtype_o: torch.dtype,
        dtype_idx: torch.dtype,
        head_dim_qk: int,
        head_dim_vo: int,
        pos_encoding_mode: int,
    ):
        uri = get_batch_attention_uri(
            dtype_q,
            dtype_kv,
            dtype_o,
            dtype_idx,
            head_dim_qk,
            head_dim_vo,
            pos_encoding_mode,
        )
        additional_tensor_names = []
        additional_tensor_dtypes = []
        additional_scalar_names = []
        additional_scalar_dtypes = []
        variant_name = f"StandardAttention"
        variant_decl = f"#include<flashinfer/attention/variants.cuh>"

        return gen_customize_batch_attention_module(
            uri,
            dtype_q,
            dtype_kv,
            dtype_o,
            dtype_idx,
            head_dim_qk,
            head_dim_vo,
            additional_tensor_names,
            additional_tensor_dtypes,
            additional_scalar_names,
            additional_scalar_dtypes,
            variant_name,
            variant_decl,
            pos_encoding_mode=pos_encoding_mode,
        )
    size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
                                         cudaMemcpyHostToDevice, stream));
    constexpr size_t sizeof_dtype_o = 2;  // NOTE (Yilong): assume fp16
The size of `DTypeO` is assumed to be 2 bytes (e.g., fp16/bf16), and this assumption is hardcoded. If output types with different sizes are to be supported in the future, this should be made more flexible, for example by using `sizeof(DTypeO)` if `DTypeO` were available here as a template parameter, or by passing the size in.
    cluster_cost_heap.insert(
        {cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)});
Regarding the performance issue with reduction, I guess its cost should be included in the scheduler as well; see https://hazyresearch.stanford.edu/blog/2025-03-04-thundermla
I had some discussion with @happierpig; there are several optimizations we need to take into consideration:
- Count the reduction cost as well (note that we can first add the reduction cost to the priority queue, then prefill, then decode, because the reduction cost is long) -- see the toy sketch below.
- Consider two-stage reduction: merging 64 chunks on one CTA -> (merging 8 chunks on 8 CTAs -> merging 8 chunks on 1 CTA).
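To make the idea concrete, here is a toy Python sketch of the greedy min-heap assignment with a made-up cost model that also charges each extra KV split for its later reduction work. It only illustrates the scheduling idea; the real logic lives in the C++ scheduler around `cluster_cost_heap`:

```python
import heapq

def schedule(work_tiles, num_clusters, reduction_cost_per_split=64):
    """Greedily assign work tiles to the least-loaded cluster.

    work_tiles: list of (cluster_tile_q, kv_len, num_kv_splits) tuples.
    Cost model is illustrative: attention cost ~ tile_q * kv_len, plus a fixed
    reduction charge per extra KV split, as suggested in the thread above.
    """
    heap = [(0, cluster_idx, []) for cluster_idx in range(num_clusters)]
    heapq.heapify(heap)
    # Push long-running items first so the greedy packing balances better
    # (the thread suggests ordering reduction, then prefill, then decode).
    for tile_q, kv_len, num_splits in sorted(
        work_tiles, key=lambda t: t[0] * t[1], reverse=True
    ):
        cost = tile_q * kv_len + reduction_cost_per_split * max(num_splits - 1, 0)
        accum_cost, cluster_idx, assigned = heapq.heappop(heap)
        assigned.append((tile_q, kv_len))
        heapq.heappush(heap, (accum_cost + cost, cluster_idx, assigned))
    return sorted(heap)

# Example: a few prefill tiles plus many decode tiles spread over 4 clusters.
tiles = [(128, 4096, 1)] * 4 + [(16, 8192, 8)] * 2 + [(16, 256, 1)] * 32
for accum_cost, cluster_idx, assigned in schedule(tiles, num_clusters=4):
    print(cluster_idx, accum_cost, len(assigned))
```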
Thanks, @Edenzzzz and @yzh119! Yeah, I think there are two levels of inefficiency in the reduction operations:
- Inter-stage wave quantization. Bubbles are introduced by `grid.sync` between the core attention operation and the reduction operation. I think some fine-grained dependency control is needed to reduce this. After adding the barrier (or somehow guaranteeing the dependency), we can count the reduction cost as well.
- Intra-stage imbalance. The reduction kernel itself is implemented in a coarse-grained manner. For example, some threads are idle when the `seq_len` of a `work_tile` is short, while a `work_tile` with a very long `seq_len` could be shared further (see the referenced condition `(iter + num_smem_stages) * bdy + ty < num_index_sets);`).

I will take a look at it after adding the profiler.
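For readers less familiar with this code path: each KV split produces a partial output plus a log-sum-exp, and the reduction stage merges them. Below is a minimal PyTorch sketch of that merge (in the spirit of FlashInfer's `merge_states`, written out here purely for illustration) together with the two-stage variant proposed above:

```python
import torch

def merge_partial_states(partial_out: torch.Tensor, partial_lse: torch.Tensor):
    """Merge per-chunk attention outputs using their log-sum-exp weights.

    partial_out: [num_chunks, num_heads, head_dim] partial attention outputs.
    partial_lse: [num_chunks, num_heads] log-sum-exp of scores per chunk.
    Returns merged output [num_heads, head_dim] and merged lse [num_heads].
    """
    lse_max = partial_lse.max(dim=0, keepdim=True).values        # numerical stability
    weights = torch.exp(partial_lse - lse_max)                    # [num_chunks, num_heads]
    merged_out = (weights.unsqueeze(-1) * partial_out).sum(0) / weights.sum(0).unsqueeze(-1)
    merged_lse = lse_max.squeeze(0) + torch.log(weights.sum(0))
    return merged_out, merged_lse

# Two-stage variant sketched in the thread: merge 64 chunks as 8 groups of 8,
# then merge the 8 group results, instead of one 64-way merge on a single CTA.
out = torch.randn(64, 28, 128)
lse = torch.randn(64, 28)
stage1 = [merge_partial_states(out[i:i + 8], lse[i:i + 8]) for i in range(0, 64, 8)]
out2, lse2 = merge_partial_states(torch.stack([o for o, _ in stage1]),
                                  torch.stack([l for _, l in stage1]))
```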
    BlockPersistentRunner1::Run(params_1, &smem_storage_1);
    PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1);

    __syncthreads();
Removing the first `__syncthreads()` should be feasible, as all work across CTAs in this stage is independent. However, the `grid.sync()` here is for guaranteeing the dependency: it ensures the reduction starts only after all partial results have been correctly calculated.
We should use CTA-wise barriers instead of `grid.sync()`.
> We should use CTA-wise barriers instead of `grid.sync()`.

Agreed, and this fine-grained dependency control will be a necessary component for counting the reduction cost.
@yzh119 I am curious whether `grid.sync()` gives any guarantee on inter-CTA memory ordering. What if some partial results are kept in L1 while another CTA tries to do the reduction and loads stale data?
Here is the source of `grid.sync()` in the cooperative_groups library; it basically relies on an arrival counter in global memory that is updated with `.gpu` memory scope.
typedef unsigned int barrier_t;
_CG_STATIC_QUALIFIER bool bar_has_flipped(unsigned int old_arrive, unsigned int current_arrive) {
return (((old_arrive ^ current_arrive) & 0x80000000) != 0);
}
_CG_STATIC_QUALIFIER bool is_cta_master() {
return (threadIdx.x + threadIdx.y + threadIdx.z == 0);
}
_CG_STATIC_QUALIFIER unsigned int sync_grids_arrive(volatile barrier_t *arrived) {
unsigned int oldArrive = 0;
__barrier_sync(0);
if (is_cta_master()) {
unsigned int expected = gridDim.x * gridDim.y * gridDim.z;
bool gpu_master = (blockIdx.x + blockIdx.y + blockIdx.z == 0);
unsigned int nb = 1;
if (gpu_master) {
nb = 0x80000000 - (expected - 1);
}
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_70,
// Barrier update with release; polling with acquire
asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory");
,
// Fence; barrier update; volatile polling; fence
__threadfence();
oldArrive = atomicAdd((unsigned int*)arrived, nb);
);
}
return oldArrive;
}
_CG_STATIC_QUALIFIER void sync_grids_wait(unsigned int oldArrive, volatile barrier_t *arrived) {
if (is_cta_master()) {
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_70,
unsigned int current_arrive;
do {
asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory");
} while (!bar_has_flipped(oldArrive, current_arrive));
,
while (!bar_has_flipped(oldArrive, *arrived));
__threadfence();
);
}
__barrier_sync(0);
}
I see. I believe this paired `acquire.gpu` / `release.gpu` will guarantee the L1 flush.
Let's merge this in first and move on with follow-up PRs to fix performance.
📌 Description
This is a follow-up PR on #1137, optimizing scheduling balance and reduction overhead and achieving a 2x speedup. Optimizations include:
1. Conduct fewer kv-splits on decode requests. This reduces the number of reduction operations needed.
2. Write-through. Attention without kv-split is written directly into the final buffer without an additional copy.
3. Warp-level reduction. Changing CTA-wise reduction into warp-wise reduction increases parallelism, greatly reducing the reduction overhead.

This PR also investigates a fine-grained gmem barrier as discussed in #1137. However, due to the overhead of memory-ordering operations, we do not see a performance gain on our use cases (static scheduling), so it is left for future work. Implementations are available in a separate fork (https://github.com/happierpig/flashinfer-ai/tree/fine-grained-barrier-reduction).
Perf benchmarks (`benchmarks/bench_batch_attention.py`) on `head_dim=128, num_qo_heads=28, num_kv_heads=4, page_size=1` with H200 (benchmark figures omitted).
Reviewer Notes: cc @yzh119 @yyihuang
… in Qwen family (#1614)
📌 Description
This PR fixes precision issues of BatchAttention (persistent FA2 of #1137) when `CTA_TILE_Q` is not a multiple of `gqa_group_size` (e.g., Qwen family models). The prior implementation assumes that all `qo_heads` of a `kv_head` on a specific token will either all be split-kv or all be non-split-kv. However, when `gqa_group_size == 7`, some `qo_heads` can be non-split while the remaining are split.
Reviewer Notes: cc @Edenzzzz


📌 Description
Follow up of #858, #967, and #1026, this PR aims to provide an efficient and unified API for processing prefill and decode requests within a single kernel launch. Key features include:
1. Single CUDA graph capture for all batch sizes and sequence lengths. Prior to this PR, the FA2 template was implemented as a non-persistent kernel, which dispatches `padded_batch_sizes` CTAs and relies on static information (ref: `flashinfer/include/flashinfer/attention/scheduler.cuh`, line 527 at `f484fd3`). This necessitates a specialized CUDA graph for each batch with different seqlens and batch sizes to maximize throughput. Furthermore, prefill and decode are executed by different kernel launches, increasing the number of CUDA graphs combinatorially. This PR implements a persistent-style kernel, which enables a single CUDA graph to capture work for all seqlens and batch sizes.
2. Dynamic specialization for prefill and decode. Implemented as a persistent kernel, prefill and decode requests are dynamically executed by an efficient kernel template with suitable hyperparameters. For example, decode requests with `qo_len=1` are processed with `CTA_TILE_Q=16`, while prefill requests with `qo_len>=128` are processed with `CTA_TILE_Q=128` (a toy version of this dispatch rule is sketched right after this list).
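As a rough illustration of the dispatch rule in item 2 above, a toy version might look like the following. Only the `qo_len=1 -> 16` and `qo_len>=128 -> 128` cases come from this PR; everything else is an assumption:

```python
def pick_cta_tile_q(qo_len: int, gqa_group_size: int = 1) -> int:
    """Toy dispatch rule for the query tile size.

    Only qo_len == 1 -> 16 and qo_len >= 128 -> 128 are stated in the PR;
    the intermediate tile size and the GQA packing are assumptions.
    """
    rows = qo_len * gqa_group_size   # if GQA heads are packed into tile rows (assumption)
    if rows <= 16:
        return 16                    # decode / very short query chunks
    if rows < 128:
        return 64                    # assumed middle band (not stated in the PR)
    return 128                       # long prefill chunks
```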
Perf Benchmarks:
The benchmark script is at `benchmarks/bench_batch_attention.py` and was tested with Qwen-2.5-7B configurations on a single H200. Visualization (benchmark figure omitted):
1. ~30% bandwidth boost in hybrid prefill/decode scenarios.
2. Slightly worse performance on pure workloads, which may be caused by the reduction overhead.
Unit Tests:
Unit tests can be located at `tests/bench_batch_attention.py` (test results screenshot omitted). A minimal reference oracle in the spirit of these tests is sketched below.
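To make the correctness criterion concrete, the kind of reference oracle such tests compare against can be written in a few lines of PyTorch. This is a generic sketch, not the code in `tests/`:

```python
import math
import torch

def ref_batch_attention(q, k, v, qo_indptr, kv_indptr, causal=True):
    """Naive reference: per-request causal attention with GQA head broadcast.

    q:    [total_qo_len, num_qo_heads, head_dim]
    k, v: [total_kv_len, num_kv_heads, head_dim]
    qo_indptr / kv_indptr: int tensors of length batch_size + 1.
    """
    num_qo_heads, num_kv_heads = q.shape[1], k.shape[1]
    group = num_qo_heads // num_kv_heads
    outs = []
    for b in range(qo_indptr.numel() - 1):
        qb = q[qo_indptr[b]:qo_indptr[b + 1]].float()                 # [Lq, Hq, D]
        kb = k[kv_indptr[b]:kv_indptr[b + 1]].float().repeat_interleave(group, dim=1)
        vb = v[kv_indptr[b]:kv_indptr[b + 1]].float().repeat_interleave(group, dim=1)
        scores = torch.einsum("qhd,khd->hqk", qb, kb) / math.sqrt(q.shape[-1])
        if causal:
            Lq, Lk = qb.shape[0], kb.shape[0]
            # Align the query block to the end of the KV sequence (prefix cache).
            mask = torch.arange(Lk)[None, :] > torch.arange(Lq)[:, None] + (Lk - Lq)
            scores.masked_fill_(mask.to(scores.device), float("-inf"))
        outs.append(torch.einsum("hqk,khd->qhd", scores.softmax(-1), vb))
    return torch.cat(outs).to(q.dtype)
```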
Future works:
1. Add a profiler to analyze the performance bottleneck.
2. Optimize the reduction kernel schedule.
🔍 Related Issues
#1022
Advised by @yzh119. CC @AKKamath @Edenzzzz
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
Reviewer Notes