
Conversation

@happierpig
Collaborator

Description

This PR is a follow-up to #858 and integrates the PoDAttention (arXiv link) API in a user-transparent manner. Users can now invoke PoDAttention via the same API as BatchPrefillWithPagedKVCache, without explicitly specifying whether requests are prefill or decode (example code).

Key Changes

  1. Support for Non-Continuous Q/O and KV Tensor Layout
    Previously, tensor offsets were computed from indptr, which assumes a continuous layout. PoDAttention requires supporting mixed prefill/decode subsets of requests within a batch, which necessitates a non-continuous layout.

    • Added q_lenptr and kv_lenptr to support this layout (code link); a minimal sketch follows the list below.
  2. Horizontal Fusion-Style Implementation
    For improved efficiency, the prefill and decode subsets are scheduled with awareness of each other, enabling better selection of kernel hyperparameters and persistent kernel execution.

    • The current resource partitioning strategy depends solely on the total KV-cache load size (scheduler code).
    • Note: This strategy is customizable based on specific workloads.
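
To make the layout change concrete, here is a minimal sketch (plain Python, not the actual FlashInfer scheduler; the function names are made up for illustration) of why per-request length arrays such as q_lenptr/kv_lenptr are needed once offsets can no longer be derived from a single indptr:

```python
# Hypothetical illustration of continuous vs. non-continuous Q/O layouts.

def continuous_slices(indptr):
    # With a continuous layout, indptr[i]..indptr[i+1] is request i's slice,
    # so each offset is implicitly the prefix sum of the preceding lengths.
    return [(indptr[i], indptr[i + 1] - indptr[i]) for i in range(len(indptr) - 1)]

def noncontinuous_slices(offsets, lengths):
    # With per-request length arrays (the role q_lenptr / kv_lenptr play),
    # offsets are free to point anywhere in the flattened tensor, so the
    # prefill and decode subsets can be laid out independently.
    return list(zip(offsets, lengths))

# Three requests: a 128-token prefill, a 1-token decode, another 128-token prefill.
print(continuous_slices([0, 128, 129, 257]))
print(noncontinuous_slices(offsets=[0, 4096, 128], lengths=[128, 1, 128]))
```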

Limitations and Future Work

  • CUDA Graph is currently not supported. Only FA2 is supported at this stage.
  • The workload classifier (qo_len > threshold) is preliminary and requires improvement (classifier implementation); a toy sketch appears below.
  • Performance tuning is ongoing, and correctness has only been validated on a limited set of unit tests (unit tests).
    cc @AKKamath @yzh119
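
For reference, the preliminary classifier mentioned above boils down to a single query-length threshold. A toy sketch (illustrative only; the function name and default threshold are hypothetical, not taken from the implementation):

```python
# Toy version of the qo_len > threshold classifier (not the actual code).
def classify_requests(qo_lens, threshold=1):
    # Requests with more than `threshold` query tokens are treated as prefill,
    # the rest as decode.
    prefill_ids = [i for i, qo_len in enumerate(qo_lens) if qo_len > threshold]
    decode_ids = [i for i, qo_len in enumerate(qo_lens) if qo_len <= threshold]
    return prefill_ids, decode_ids

print(classify_requests([512, 1, 1, 64, 1]))  # ([0, 3], [1, 2, 4])
```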

@happierpig happierpig requested a review from yzh119 March 21, 2025 21:28
@yzh119
Collaborator

yzh119 commented Mar 21, 2025

Some of the unit tests failed, for example `test_block_sparse_attention[False-256-16-16-128-64-16-4]`:

RuntimeError: Error in function 'PrefillSplitQOKVIndptr' at /workspace/flashinfer/data/include/flashinfer/attention/scheduler.cuh:515: kv_len_ptr_h[0]: 0 should be positive

@Edenzzzz
Contributor

Edenzzzz commented Apr 9, 2025

Hi, can I ask when this is planned to be merged? I made a PR to support POD Attn in SGLang using the old API and plan to get that working with CUDA graph first.
sgl-project/sglang#5169

@AKKamath
Contributor

AKKamath commented Apr 9, 2025

I really like the uniform batch API that this PR presents.

I ran this on an A100 and compared it with the existing FlashInfer POD-Attention implementation. On average this performed around 10 - 15% worse, but still better than serial execution. Performance was worse for larger prefill context lengths, while for smaller context lengths the performance was more comparable.

@Edenzzzz
Contributor

Edenzzzz commented Apr 9, 2025

Yeah, this is more convenient. One issue I had during my PR is that I have to fill a 2D attention mask for prefill every time, instead of using the page table & indices.

@yzh119
Collaborator

yzh119 commented Apr 10, 2025

Hi @Edenzzzz @AKKamath, I'm working on another branch following this idea; it will be merged in the next few days.

@Edenzzzz
Contributor

Will the old API be preserved? Thanks.

@Edenzzzz
Contributor

@AKKamath Btw, I wonder what the reason was for using a mask instead of a page table for the prefill qkv?

@AKKamath
Contributor

@yzh119 can correct me here, but I believe the mask-based prefill kernel (single_prefill) had better performance than the page-table prefill because the page-table version had higher register usage, causing register spills.

@Edenzzzz
Contributor

Edenzzzz commented Apr 12, 2025

But don't we waste a lot of space storing the 2D mask? For example, the default shape is the 2D cumulative sequence lengths (qo_lens, kv_lens), but when converting from the page table (qo_indptr, kv_indptr) to the mask, the mask becomes very sparse, with each query attending to only a few KV entries of its own request within the whole cumulative sequence. It can also be expensive to fill the mask.
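
To put a rough number on the space concern, a back-of-envelope comparison (hypothetical sizes, not measurements from this PR):

```python
# Dense 2D custom mask vs. page-table representation (hypothetical batch).
sum_qo, sum_kv, page_size, batch = 4096, 65536, 16, 8

mask_entries = sum_qo * sum_kv                               # dense 2D mask
page_table_entries = (batch + 1) * 2 + sum_kv // page_size   # two indptr arrays + page indices

print(mask_entries, page_table_entries)  # ~2.7e8 vs. ~4.1e3 entries
```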

@Edenzzzz
Contributor

Actually, I realized POD Attention is not designed to mix many prefill requests with decode requests; it mixes one prefill at a time, so causal masking can be used without any custom mask.

@yzh119
Collaborator

yzh119 commented Apr 20, 2025

Follow up in #1026 .

@yzh119 yzh119 closed this Apr 20, 2025
Comment on lines +1451 to +1452
std::accumulate(qo_len_ptr_h_p.begin(), qo_len_ptr_h_p.end(), 0) +
2 * page_size * std::accumulate(kv_len_ptr_h_p.begin(), kv_len_ptr_h_p.end(), 0);
Contributor


Hi, I'm interested in implementing a persistent POD Attn and have some questions. Why don't we use `qo_len_ptr_h_p[i] * kv_len_ptr_h_p[i] * 2` here to model the quadratic compute load? Thanks.

Collaborator Author


I am using the current calculation mainly to model memory load rather than compute load. Different workloads may favor different heuristics, so it would be helpful if you benchmark and decide.
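
For readers following this thread, the two heuristics being compared can be written out as a small sketch (illustrative Python only, not the scheduler code; the variable names mirror the quoted C++ snippet but the functions are made up):

```python
# qo_lens are per-request query lengths; kv_lens count KV pages per request,
# as in the C++ snippet above.

def memory_load(qo_lens, kv_lens, page_size):
    # What the current code models: data moved, i.e. Q/O rows plus paged K and V
    # (the factor 2 covers both K and V), summed over all requests.
    return sum(qo_lens) + 2 * page_size * sum(kv_lens)

def compute_load(qo_lens, kv_lens, page_size):
    # What the question proposes: attention FLOPs, which grow with
    # qo_len * kv_len per request and are therefore quadratic for prefill.
    return sum(2 * q * kv * page_size for q, kv in zip(qo_lens, kv_lens))

qo_lens, kv_lens, page_size = [128, 1, 1], [16, 64, 64], 16
print(memory_load(qo_lens, kv_lens, page_size), compute_load(qo_lens, kv_lens, page_size))
```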

Collaborator Author


BTW, #1026 will be the upstream version and this PR has been deprecated. It would be helpful if you could refer directly to the new PR.

Contributor


Thanks. Do you have any plans to adapt POD Attn to the persistent template? I also plan to work on that.

yzh119 added a commit that referenced this pull request Jun 12, 2025

## 📌 Description
Follow up of #858,
#967, and
#1026, this PR aims to
provide an efficient and unified API for processing prefill and decode
requests within a single kernel launch. Key features include:
1. Single CUDA graph capture for all batch sizes and sequence lengths.
Prior to this PR, the FA2 template was implemented as a non-persistent
kernel, which dispatches `padded_batch_sizes` CTAs and relies on static
information (ref:
https://github.com/flashinfer-ai/flashinfer/blob/f484fd3c7f09a1d0afb75d779872b9762a35e445/include/flashinfer/attention/scheduler.cuh#L527).
This necessitates a specialized CUDA graph for each combination of batch
size and sequence lengths to maximize throughput. Furthermore, prefill
and decode are executed by separate kernel launches, multiplying the
number of required CUDA graphs. This PR implements a persistent-style
kernel, which enables a single CUDA graph to capture work for all
sequence lengths and batch sizes.
2. Dynamic specialization for prefill and decode. Implemented as a
persistent kernel, prefill and decode requests are dynamically executed
by an efficient kernel template with suitable hyperparameters. For
example, decode requests with `qo_len=1` are processed with
`CTA_TILE_Q=16`, while prefill requests with `qo_len>=128` are processed
with `CTA_TILE_Q=128` (a toy sketch of this dispatch rule follows the list).
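
As an illustration of this dispatch rule, a toy sketch (not the actual kernel-selection logic; the behavior for intermediate query lengths is an assumption):

```python
# Toy sketch: map a request's query length to the quoted CTA_TILE_Q values.
def pick_cta_tile_q(qo_len: int) -> int:
    # Decode requests (qo_len == 1) use the small tile; long prefills
    # (qo_len >= 128) use the large tile. Intermediate lengths falling back
    # to the small tile is an assumption made only for this illustration.
    return 128 if qo_len >= 128 else 16

for qo_len in (1, 32, 128, 4096):
    print(qo_len, "->", pick_cta_tile_q(qo_len))
```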

## Perf Benchmarks:
The benchmark script is at `benchmarks/bench_batch_attention.py` and was
tested with Qwen-2.5-7B configurations and a single H200. Visualization:
<img width="594" alt="image"
src="https://github.com/user-attachments/assets/735aca14-387d-4013-b3f4-e199b6cff5f3"
/>
1. 30% bandwidth boost in hybrid scenarios
2. Slightly worse performance on pure workloads, which may be caused by the
reduction overhead

## Unit Tests:
Unit tests are located at `tests/bench_batch_attention.py`.
<img width="1527" alt="image"
src="https://github.com/user-attachments/assets/fff06c6d-c121-497c-9f62-039653149a4d"
/>

## Future work:
1. Add a profiler to analyze performance bottlenecks
2. Optimize the reduction kernel schedule

## 🔍 Related Issues
#1022

Advised by @yzh119. CC @AKKamath @Edenzzzz 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).


Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
Anerudhan pushed a commit to Anerudhan/flashinfer that referenced this pull request Jun 28, 2025