
Conversation

@happierpig (Collaborator)

📌 Description

Following up on #858, #967, and #1026, this PR provides an efficient, unified API for processing prefill and decode requests within a single kernel launch. Key features include:

  1. Single CUDA graph capture for all batch sizes and sequence lengths. Prior to this PR, the FA2 template used a non-persistent kernel design, which dispatches padded_batch_sizes CTAs and relies on static information (ref:
    const uint64_t max_seq_len = total_num_rows - batch_size + 1;
    in include/flashinfer/attention/scheduler.cuh). This necessitates a specialized CUDA graph for each combination of batch size and sequence lengths to maximize throughput. Furthermore, prefill and decode are executed by separate kernel launches, multiplying the number of required CUDA graphs. This PR implements a persistent-style kernel, which enables a single CUDA graph to capture work for all sequence lengths and batch sizes.
  2. Dynamic specialization for prefill and decode. Because the kernel is persistent, prefill and decode requests are dynamically dispatched to an efficient kernel template with suitable hyperparameters. For example, decode requests with qo_len=1 are processed with CTA_TILE_Q=16, while prefill requests with qo_len>=128 are processed with CTA_TILE_Q=128 (a minimal scheduling sketch follows this list).
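
For illustration only (not code from this PR): a minimal Python sketch of the two ideas above, bucketing each request into a query tile by its qo_len and draining a single shared work queue with a fixed pool of persistent workers whose count does not depend on batch size or sequence lengths. The intermediate tile size and the work representation are assumptions, not the actual scheduler.

import queue

def select_cta_tile_q(qo_len: int) -> int:
    # Decode-like requests use a small query tile; long prefills use a large one
    # (the 16/128 thresholds are the ones quoted in the PR description).
    if qo_len == 1:
        return 16
    if qo_len >= 128:
        return 128
    return 64  # assumption: an intermediate tile for medium-length requests

def build_work_queue(qo_lens):
    # One shared queue; each entry is (request_id, chosen query tile).
    work = queue.Queue()
    for req_id, qo_len in enumerate(qo_lens):
        work.put((req_id, select_cta_tile_q(qo_len)))
    return work

def persistent_worker(work, process):
    # Analogue of a persistent CTA: keep pulling work until the queue is empty,
    # so the "launch shape" does not depend on the batch composition.
    while True:
        try:
            req_id, tile_q = work.get_nowait()
        except queue.Empty:
            return
        process(req_id, tile_q)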

Perf Benchmarks:

The benchmark script is benchmarks/bench_batch_attention.py; results were collected with Qwen-2.5-7B configurations on a single H200. Visualization:
(benchmark figure)

  1. ~30% bandwidth boost in hybrid prefill/decode scenarios.
  2. Slightly worse performance on pure prefill or pure decode workloads, likely caused by the reduction overhead.

Unit Tests:

Unit tests are located at tests/test_batch_attention.py.
(test results screenshot)

Future works:

  1. Add a profiler to analyze performance bottlenecks.
  2. Optimize the reduction kernel schedule.

🔍 Related Issues

#1022

Advised by @yzh119. CC @AKKamath @Edenzzzz

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@happierpig requested a review from @yzh119 on June 11, 2025 at 22:03.
@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @happierpig, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a significant feature by introducing a unified, persistent CUDA kernel for batch attention that efficiently handles both prefill and decode operations. This change aims to improve performance, particularly in hybrid workloads, and enhance compatibility with CUDA graph capture by allowing a single graph to cover diverse batch configurations. The implementation includes dynamic kernel specialization and is accompanied by new benchmarks and correctness tests.

Highlights

  • Unified Attention Kernel: Introduced a new BatchAttention API and underlying CUDA kernel designed to handle both prefill and decode requests within a single kernel launch, replacing the previous approach that used separate kernels.
  • Persistent Kernel & CUDA Graph Compatibility: The new kernel is implemented in a 'persistent-style', enabling a single CUDA graph to capture attention operations for varying batch sizes and sequence lengths. This contrasts with the prior method which required specialized CUDA graphs for different batch configurations.
  • Dynamic Specialization: The persistent kernel dynamically selects efficient hyperparameters (like CTA tile sizes) based on the characteristics of the requests, specializing for short decode sequences (qo_len=1) and longer prefill sequences (qo_len>=128).
  • Performance Improvements: Benchmarks show a 30% bandwidth boost in hybrid prefill/decode scenarios compared to the old implementation. Pure workloads may see slightly worse performance due to reduction overhead, which is noted as a future optimization area.
  • Comprehensive Correctness Tests: Added a new test suite (tests/test_batch_attention.py) that verifies the correctness of the new BatchAttention implementation against the old BatchPrefillWithPagedKVCacheWrapper across a wide range of sequence length configurations, page sizes, head counts, head dimensions, causality settings, layouts, and data types.
  • New Python API: Exposed the new functionality through a flashinfer.BatchAttention Python class, providing plan and run methods similar to existing wrappers but for the unified kernel (a hypothetical usage sketch follows these highlights).
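
A hypothetical usage sketch of the new class. The class name and the plan/run methods are confirmed by this PR, but the argument names, shapes, and the run return value below are assumptions modelled on the existing paged-KV wrappers; check the actual docstrings before use.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 28, 4, 128, 1
qo_lens = torch.tensor([1, 1, 128])    # two decode requests + one prefill request
kv_lens = torch.tensor([64, 64, 128])  # KV pages per request (page_size = 1)
qo_indptr = torch.cat([torch.zeros(1, dtype=torch.int32), qo_lens.cumsum(0).int()]).cuda()
kv_indptr = torch.cat([torch.zeros(1, dtype=torch.int32), kv_lens.cumsum(0).int()]).cuda()
kv_indices = torch.arange(int(kv_lens.sum()), dtype=torch.int32).cuda()

q = torch.randn(int(qo_lens.sum()), num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
kv_cache = torch.randn(int(kv_lens.sum()), 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")  # paged KV cache, NHD layout

wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_lens.int().cuda(),
    num_qo_heads, num_kv_heads, head_dim, head_dim, page_size,
    causal=True, q_data_type=torch.float16, kv_data_type=torch.float16,
)
out, lse = wrapper.run(q, kv_cache)  # prefill and decode handled in one kernel launch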

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a unified batch attention mechanism using a persistent CUDA kernel, aiming to improve efficiency, particularly for CUDA graph capture, by combining prefill and decode operations. The changes include new CUDA kernels, a scheduler, Python bindings, benchmarks, and correctness tests.

Key feedback points include:

  • A JIT module caching issue in flashinfer/attention.py that could lead to unnecessary recompilations.
  • A duplicated function definition in flashinfer/jit/attention/pytorch.py.
  • Some hardcoded values (e.g., type sizes, buffer limits) in the C++ scheduler and kernel code that might affect flexibility or robustness under different configurations.
  • Minor typos in comments and code.

The overall approach of using a persistent kernel with dynamic specialization for prefill/decode and load balancing seems sound. The provided benchmarks and tests are valuable for verifying correctness and performance.

global _batch_attention_modules
modules_dict = _batch_attention_modules
if args not in modules_dict:
    module = gen_batch_attention_module(*args).build_and_load()
gemini-code-assist (bot), severity: high

The JIT-compiled module is not being cached. After building a new module, it should be stored in modules_dict to avoid recompilation on subsequent calls with the same arguments. This can significantly impact performance if get_holistic_attention_module is called multiple times with identical configurations.

Suggested change:
- module = gen_batch_attention_module(*args).build_and_load()
+ module = gen_batch_attention_module(*args).build_and_load()
+ modules_dict[args] = module  # Cache the newly built module
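
A minimal sketch of the cached getter with the suggested write added, assuming the surrounding function is shaped like the snippet above (the getter name is taken from the review comment; everything beyond the quoted lines is an assumption):

_batch_attention_modules = {}  # maps argument tuples to built JIT modules

def get_holistic_attention_module(*args):
    global _batch_attention_modules
    modules_dict = _batch_attention_modules
    if args not in modules_dict:
        module = gen_batch_attention_module(*args).build_and_load()
        modules_dict[args] = module  # store it so identical configs skip recompilation
    return modules_dict[args]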

Comment on lines +897 to +937
def gen_batch_attention_module(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    dtype_idx: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    pos_encoding_mode: int,
):
    uri = get_batch_attention_uri(
        dtype_q,
        dtype_kv,
        dtype_o,
        dtype_idx,
        head_dim_qk,
        head_dim_vo,
        pos_encoding_mode,
    )
    additional_tensor_names = []
    additional_tensor_dtypes = []
    additional_scalar_names = []
    additional_scalar_dtypes = []
    variant_name = f"StandardAttention"
    variant_decl = f"#include<flashinfer/attention/variants.cuh>"

    return gen_customize_batch_attention_module(
        uri,
        dtype_q,
        dtype_kv,
        dtype_o,
        dtype_idx,
        head_dim_qk,
        head_dim_vo,
        additional_tensor_names,
        additional_tensor_dtypes,
        additional_scalar_names,
        additional_scalar_dtypes,
        variant_name,
        variant_decl,
        pos_encoding_mode=pos_encoding_mode,
    )
gemini-code-assist (bot), severity: high

The function gen_batch_attention_module is defined twice with the exact same signature and body (lines 854-894 and 897-937). One of these definitions is redundant and should be removed to improve code maintainability and avoid confusion.

size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
constexpr size_t sizeof_dtype_o = 2; // NOTE (Yilong): assume fp16
gemini-code-assist (bot), severity: medium

The size of DTypeO is assumed to be 2 bytes (e.g., fp16/bf16). This assumption is hardcoded. If other output types with different sizes are to be supported in the future, this might need to be made more flexible, perhaps by using sizeof(DTypeO) from the template parameters if DTypeO were available here, or by passing the size.

Comment on lines +1208 to +1209
cluster_cost_heap.insert(
{cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)});
@Edenzzzz (Contributor), Jun 11, 2025:

Regarding performance issues of reduction, I guess its cost should be included in the scheduler
see https://hazyresearch.stanford.edu/blog/2025-03-04-thundermla

Collaborator:

I had some discussion with @happierpig; there are several optimizations we need to take into consideration:

  1. Count the reduction cost as well (note that we can add the reduction cost to the priority queue first, then prefill, then decode, because the reduction cost is long); a sketch of this assignment order follows below.
  2. Consider two-stage reduction: instead of merging 64 chunks on one CTA, first merge 8 chunks on each of 8 CTAs, then merge the 8 partial results on 1 CTA.
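
Illustrative only, not the C++ scheduler's actual code: a minimal Python sketch of least-loaded cluster assignment where the (long) reduction work is enqueued before prefill and decode, so the most expensive items get balanced first. The cost function and work representation are assumptions.

import heapq

def assign_work(num_clusters, reduction_works, prefill_works, decode_works, cost_fn):
    # Min-heap of (accumulated cost, cluster index): always give the next work item
    # to the least-loaded cluster, mirroring the cluster_cost_heap quoted above.
    heap = [(0.0, c) for c in range(num_clusters)]
    heapq.heapify(heap)
    assignment = {c: [] for c in range(num_clusters)}
    # Order matters: reduction first, then prefill, then decode.
    for work in list(reduction_works) + list(prefill_works) + list(decode_works):
        accum_cost, cluster = heapq.heappop(heap)
        assignment[cluster].append(work)
        heapq.heappush(heap, (accum_cost + cost_fn(work), cluster))
    return assignment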

@happierpig (Collaborator, Author):

Thanks, @Edenzzzz and @yzh119! Yeah, I think there are two levels of inefficiency in the reduction operation:

  1. Inter-stage wave quantization. Bubbles are introduced by grid.sync() between the core attention operation and the reduction operation. I think some fine-grained dependency control is needed to reduce this. After adding the barrier (or otherwise guaranteeing the dependency), we can count the reduction cost as well.
  2. Intra-stage imbalance. The reduction kernel itself is implemented in a coarse-grained manner. For example, some threads are idle when the seq_len of a work_tile is short, while work_tiles with very long seq_len could be split further and shared (ref:
    (iter + num_smem_stages) * bdy + ty < num_index_sets);
    ). I will take a look at it after adding the profiler. (A sketch of the state-merging step this reduction performs follows below for context.)
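
For context only (not this PR's kernel code): the reduction combines partial attention outputs from split-KV chunks using their log-sum-exp values. A minimal torch sketch of merging two partial states:

import torch

def merge_partial_states(o1, lse1, o2, lse2):
    # o1, o2: [num_heads, head_dim] partial outputs; lse1, lse2: [num_heads] log-sum-exp.
    lse = torch.logaddexp(lse1, lse2)          # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial result
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse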

BlockPersistentRunner1::Run(params_1, &smem_storage_1);
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1);

__syncthreads();
Contributor:

I guess we can just remove the __syncthreads() here and the grid.sync() below? I tested that, after removing them, the precision tests still pass and the bandwidth is higher.

(before/after bandwidth screenshots)

@happierpig (Collaborator, Author), Jun 11, 2025:

Removing the first __syncthreads() should be feasible, as all work across CTAs in this stage is independent. However, the grid.sync() here is there to guarantee the dependency: it ensures the reduction starts only after all partial results have been correctly computed.

Collaborator:

We should use CTA-wise barriers instead of grid.sync().

@happierpig (Collaborator, Author):

> We should use CTA-wise barriers instead of grid.sync().

Agreed. This fine-grained dependency control will also be a necessary component for counting the reduction cost.

@happierpig (Collaborator, Author), Jun 11, 2025:

@yzh119 I am curious whether grid.sync() provides any guarantee on inter-CTA memory ordering. What if some partial results are still sitting in one CTA's L1 while another CTA tries to do the reduction and loads stale data?

Collaborator:

Here is the source of grid.sync() in the cooperative_groups library; it basically relies on an arrival counter in global memory that is updated with .gpu memory scope.

typedef unsigned int barrier_t;

_CG_STATIC_QUALIFIER bool bar_has_flipped(unsigned int old_arrive, unsigned int current_arrive) {
    return (((old_arrive ^ current_arrive) & 0x80000000) != 0);
}

_CG_STATIC_QUALIFIER bool is_cta_master() {
    return (threadIdx.x + threadIdx.y + threadIdx.z == 0);
}

_CG_STATIC_QUALIFIER unsigned int sync_grids_arrive(volatile barrier_t *arrived) {
    unsigned int oldArrive = 0;

    __barrier_sync(0);

    if (is_cta_master()) {
        unsigned int expected = gridDim.x * gridDim.y * gridDim.z;
        bool gpu_master = (blockIdx.x + blockIdx.y + blockIdx.z == 0);
        unsigned int nb = 1;

        if (gpu_master) {
            nb = 0x80000000 - (expected - 1);
        }

NV_IF_ELSE_TARGET(NV_PROVIDES_SM_70,
        // Barrier update with release; polling with acquire
        asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory");
        ,
        // Fence; barrier update; volatile polling; fence
        __threadfence();
        oldArrive = atomicAdd((unsigned int*)arrived, nb);
        );
    }

    return oldArrive;
}


_CG_STATIC_QUALIFIER void sync_grids_wait(unsigned int oldArrive, volatile barrier_t *arrived) {
    if (is_cta_master()) {
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_70,
        unsigned int current_arrive;
        do {
            asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory");
        } while (!bar_has_flipped(oldArrive, current_arrive));
        ,
        while (!bar_has_flipped(oldArrive, *arrived));
        __threadfence();
        );
    }

    __barrier_sync(0);
}

@happierpig (Collaborator, Author):

I see. I believe this paired acquire.gpu / release.gpu will guarantee the L1 flush.

@yzh119 (Collaborator) left a comment:

Let's merge this in first and address performance in follow-up PRs.

@yzh119 merged commit 568ab6c into flashinfer-ai:main on Jun 12, 2025 (2 checks passed).
Anerudhan pushed a commit to Anerudhan/flashinfer that referenced this pull request on Jun 28, 2025 (its commit message duplicates the PR description above).
yzh119 pushed a commit that referenced this pull request on Jun 30, 2025

## 📌 Description

This is a follow-up PR to #1137, optimizing scheduling balance and reduction overhead, and achieving a 2x speedup. Optimizations include:
1. Perform fewer kv-splits on decode requests. This reduces the number of reduction operations needed.
2. Write-through. Attention without kv-split is written directly into the final output buffer without an additional copy.
3. Warp-level reduction. Changing CTA-wise reduction into warp-wise reduction increases parallelism, greatly reducing the reduction overhead.

This PR also investigates a _**fine-grained gmem barrier**_ as discussed in #1137. However, due to the overhead of memory-ordering operations, we do not see a performance gain for our use case (static scheduling), so this is left for future work. An implementation is available in a separate fork
(https://github.com/happierpig/flashinfer-ai/tree/fine-grained-barrier-reduction).
Visualization: (figure)

Perf benchmarks (`benchmarks/bench_batch_attention.py`) on `head_dim=128, num_qo_heads=28, num_kv_heads=4, page_size=1` with H200: (benchmark figure)




## Reviewer Notes
cc @yzh119 @yyihuang 

Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
yzh119 pushed a commit that referenced this pull request on Sep 2, 2025: "… in Qwen family" (#1614)


## 📌 Description

This PR fixes precision issues in BatchAttention (the persistent FA2 kernel from #1137) when `CTA_TILE_Q` is not a multiple of `gqa_group_size` (e.g., Qwen-family models). The prior implementation assumes that all `qo_heads` of a `kv_head` for a specific token are either all split-kv or all non-split-kv. However, when `gqa_group_size == 7`, some `qo_heads` can be non-split while the remaining ones are split.



## Reviewer Notes

cc @Edenzzzz