
Commit 568ab6c

happierpig and yzh119 authored
[feat] add unified batch attention w/ correctness tests. (#1137)
[feat] add unified batch attention w/ correctness tests. (#1137)

## 📌 Description

Following up on #858, #967, and #1026, this PR provides an efficient, unified API for processing prefill and decode requests within a single kernel launch. Key features include:

1. **Single CUDA graph capture for all batch sizes and sequence lengths.** Prior to this PR, the FA2 template was implemented as a non-persistent kernel that dispatches `padded_batch_sizes` CTAs and relies on static information (ref: https://github.com/flashinfer-ai/flashinfer/blob/f484fd3c7f09a1d0afb75d779872b9762a35e445/include/flashinfer/attention/scheduler.cuh#L527). To maximize throughput, this requires a specialized CUDA graph for each combination of batch size and sequence lengths. Furthermore, prefill and decode are executed by separate kernel launches, which multiplies the number of CUDA graphs to capture. This PR implements a persistent-style kernel, so a single CUDA graph can capture the work for all sequence lengths and batch sizes.
2. **Dynamic specialization for prefill and decode.** Because the kernel is persistent, prefill and decode requests are dynamically dispatched to an efficient kernel template with suitable hyperparameters. For example, decode requests with `qo_len=1` are processed with `CTA_TILE_Q=16`, while prefill requests with `qo_len>=128` are processed with `CTA_TILE_Q=128`.

## Perf Benchmarks

The benchmark script is at `benchmarks/bench_batch_attention.py` and was tested with Qwen-2.5-7B configurations on a single H200.

Visualization:

<img width="594" alt="image" src="https://github.com/user-attachments/assets/735aca14-387d-4013-b3f4-e199b6cff5f3" />

1. ~30% bandwidth boost in hybrid scenarios.
2. Slightly worse performance on pure workloads, which may be caused by the reduction overhead.

## Unit Tests

Unit tests are located at `tests/bench_batch_attention.py`.

<img width="1527" alt="image" src="https://github.com/user-attachments/assets/fff06c6d-c121-497c-9f62-039653149a4d" />

## Future Work

1. Add a profiler to analyze performance bottlenecks.
2. Optimize the reduction kernel schedule.

## 🔍 Related Issues

#1022. Advised by @yzh119. CC @AKKamath @Edenzzzz

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
1 parent 35aaabb · commit 568ab6c · 18 files changed: +2119 −20 lines
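For orientation, here is a minimal usage sketch of the unified `flashinfer.BatchAttention` API exercised by this diff, distilled from `benchmarks/bench_batch_attention.py` below. The `plan`/`run` calls mirror the benchmark; the mixed decode + prefill sizes, and helper names like `pages_per_req` and `kv_cache`, are illustrative choices made for this example.

```python
import torch
import flashinfer

device = "cuda:0"
num_qo_heads, num_kv_heads, head_dim, page_size = 28, 4, 128, 16

# One decode request (qo_len=1) and one prefill request (qo_len=128) in the same batch.
qo_lens = torch.tensor([1, 128], dtype=torch.int32)
kv_lens = torch.tensor([8192, 4096], dtype=torch.int32)

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(qo_lens, 0)], dim=0).int().to(device)
pages_per_req = (kv_lens + page_size - 1) // page_size
kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(pages_per_req, 0)], dim=0).int().to(device)
num_pages = int(kv_indptr[-1])
kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device)

q = torch.rand(int(q_indptr[-1]), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_cache = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device=device)

# A single plan/run pair covers both the decode and the prefill request; the persistent
# kernel lets one CUDA graph capture work for any batch size and sequence lengths.
wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(q_indptr, kv_indptr, kv_indices, kv_lens.to(device),
             num_qo_heads, num_kv_heads, head_dim, head_dim, page_size,
             causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
wrapper.run(q, kv_cache)
```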
benchmarks/bench_batch_attention.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
```python
from __future__ import annotations

import itertools
from typing import List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from triton.testing import do_bench

import flashinfer


def run_bench(
    kv_lens: Sequence[int],
    qo_lens: Sequence[int],
    *,
    page_block_size: int,
    num_kv_heads: int,
    num_qo_heads: int,
    head_dim: int,
    device: int = 0,
    causal: bool = True,
) -> Tuple[float, float, float, float, float]:
    seq_lens = torch.tensor(kv_lens, dtype=torch.int32)
    q_lens = torch.tensor(qo_lens, dtype=torch.int32)
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()

    q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
    kv_indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
    ).int()
    num_blocks = kv_indptr[-1].item()

    q = torch.rand(
        q_indptr[-1].item(), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
    )
    kv_data = torch.randn(
        num_blocks,
        2,
        page_block_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device=device,
    )

    # baseline: BatchPrefillWithPagedKVCacheWrapper (FA2 backend)
    wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device),
        kv_layout="NHD",
        backend="fa2",
    )
    last_page_len = (seq_lens - 1) % page_block_size + 1
    wrapper_old.plan(
        q_indptr.to(device),
        kv_indptr.to(device),
        torch.arange(num_blocks, dtype=torch.int32, device=device),
        last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

    # new: unified persistent BatchAttention
    wrapper = flashinfer.BatchAttention(kv_layout="NHD")
    wrapper.plan(
        q_indptr.to(device),
        kv_indptr.to(device),
        torch.arange(num_blocks, dtype=torch.int32, device=device),
        seq_lens.to(device),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    ms_new = do_bench(lambda: wrapper.run(q, kv_data))

    total_bytes = (
        q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    )
    mem_MB = total_bytes / 1024**2
    bw_old = total_bytes / (ms_old * 1e-3) / 1024**3
    bw_new = total_bytes / (ms_new * 1e-3) / 1024**3

    return ms_old, ms_new, mem_MB, bw_old, bw_new


def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]:
    cfgs: List[List[Tuple[int, int]]] = [
        [(8192, 1)] * 128,  # decode-only
        [(4096, 128)] * 4,  # prefill-only
        [(600, 1)] * 122 + [(10_000, 17)] * 8,  # hybrid
        [(8192, 1)] * 127 * 2 + [(2048, 512)] * 1,  # hybrid (chunked-prefill)
    ]

    def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]:
        stride, sparsity = 16, 0.05
        full = np.random.randint(lo, hi, size=bsz)
        out = []
        for i, kv_len in enumerate(full):
            if i % stride == 0:
                out.append((kv_len, stride + 1))
            else:
                out.append((int(kv_len * sparsity), 1))
        return out

    cfgs.append(_rand_case(256, 1000, 8192))
    cfgs.append(_rand_case(128, 2000, 16_000))
    return cfgs


def main() -> None:
    np.random.seed(42)
    torch.random.manual_seed(42)

    seq_len_cfgs = synthesize_seq_len_configs()

    sweep = {
        "page_block_size": (1, 8, 16),
        "head_dim": (64, 128),
        "num_kv_heads": (4,),
        "num_qo_heads": (28,),
    }

    records = []

    for cfg_id, pairs in enumerate(seq_len_cfgs, start=1):
        kv_lens = [p[0] for p in pairs]
        qo_lens = [p[1] for p in pairs]
        for pbs, hd, n_kv, n_qo in itertools.product(
            sweep["page_block_size"],
            sweep["head_dim"],
            sweep["num_kv_heads"],
            sweep["num_qo_heads"],
        ):
            ms_old, ms_new, mem_MB, bw_old, bw_new = run_bench(
                kv_lens,
                qo_lens,
                page_block_size=pbs,
                num_kv_heads=n_kv,
                num_qo_heads=n_qo,
                head_dim=hd,
                device=0,
                causal=True,
            )
            records.extend(
                [
                    {
                        "scheduler": "BatchPrefillWithPagedKVCacheWrapper",
                        "seq_cfg_id": cfg_id,
                        "page_size": pbs,
                        "head_dim": hd,
                        "num_kv_heads": n_kv,
                        "num_qo_heads": n_qo,
                        "time_ms": ms_old,
                        "memory_MB": mem_MB,
                        "bandwidth_GB_s": bw_old,
                    },
                    {
                        "scheduler": "BatchAttentionWrapper",
                        "seq_cfg_id": cfg_id,
                        "page_size": pbs,
                        "head_dim": hd,
                        "num_kv_heads": n_kv,
                        "num_qo_heads": n_qo,
                        "time_ms": ms_new,
                        "memory_MB": mem_MB,
                        "bandwidth_GB_s": bw_new,
                    },
                ]
            )

    df = pd.DataFrame(
        records,
        columns=[
            "scheduler",
            "seq_cfg_id",
            "page_size",
            "head_dim",
            "num_kv_heads",
            "num_qo_heads",
            "time_ms",
            "memory_MB",
            "bandwidth_GB_s",
        ],
    )
    print(df.to_markdown(index=False, floatfmt=".2f"))


if __name__ == "__main__":
    main()
```
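The `bandwidth_GB_s` column above divides the total bytes touched (the query tensor plus the full paged KV cache, read once) by the measured kernel time, using 1024-based units. A quick back-of-the-envelope check for the decode-only config, with a hypothetical 1 ms kernel time chosen only for illustration:

```python
# Sanity check of the metric in run_bench for the decode-only config
# ([(8192, 1)] * 128) with page_block_size=16, head_dim=128, num_kv_heads=4,
# num_qo_heads=28, bfloat16 (2 bytes per element). The timing is a placeholder.
q_bytes = 128 * 28 * 128 * 2                    # 128 single-token queries
num_blocks = (8192 // 16) * 128                 # 512 pages per request x 128 requests
kv_bytes = num_blocks * 2 * 16 * 4 * 128 * 2    # (blocks, K/V, page, kv_heads, dim) in bf16
total_bytes = q_bytes + kv_bytes                # ~2 GiB, dominated by the KV cache

ms = 1.0                                        # hypothetical kernel time in milliseconds
print(total_bytes / 1024**2)                    # memory_MB       -> ~2048.9
print(total_bytes / (ms * 1e-3) / 1024**3)      # bandwidth_GB_s  -> ~2000.9 (GiB/s) at 1 ms
```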

csrc/batch_attention.cu

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
```cuda
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/pos_enc.cuh>
#include <optional>

#include "batch_attention_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
          MaskMode MASK_MODE, typename AttentionVariant, typename Params>
cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
                                          const uint32_t num_blks_x, const uint32_t num_blks_y,
                                          const cudaStream_t stream);

}  // namespace flashinfer

using namespace flashinfer;

at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer,
                                   at::Tensor int_workspace_buffer,
                                   at::Tensor page_locked_int_workspace_buffer,
                                   at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
                                   int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
                                   int64_t head_dim_o, bool causal) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  HolisticPlanInfo<2> plan_info;

  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  cudaError_t status = TwoStageHolisticPlan<IdType>(
      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
      kv_indptr.data_ptr<IdType>(), kv_len.data_ptr<IdType>(), batch_size, num_qo_heads,
      num_kv_heads, head_dim_o, causal, stream);

  TORCH_CHECK(status == cudaSuccess,
              "Failed to plan persistent paged attention, error: ", cudaGetErrorString(status));

  return vec_to_tensor(plan_info.ToVector());
}

void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
                            at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache,
                            at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                            std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
                            int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS) {
  HolisticPlanInfo<2> plan_info;
  plan_info.FromVector(tensor_to_vec(plan_info_vec));

  auto device = q.device();

  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
  void* int_buffer_ptr = int_workspace_buffer.data_ptr();

  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type = k_cache.scalar_type();

  // NOTE (Yilong): assume both q and o are NHD
  unsigned int q_stride_n = q.stride(0);
  unsigned int q_stride_h = q.stride(1);

  // the layout constraint only applies to the paged KV cache
  const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code);
  unsigned int k_stride_page = k_cache.stride(0);
  unsigned int v_stride_page = v_cache.stride(0);
  unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h;
  if (kv_layout == QKVLayout::kNHD) {
    k_stride_h = k_cache.stride(2);
    k_stride_n = k_cache.stride(1);
    v_stride_h = v_cache.stride(2);
    v_stride_n = v_cache.stride(1);
  } else {
    k_stride_h = k_cache.stride(1);
    k_stride_n = k_cache.stride(2);
    v_stride_h = v_cache.stride(1);
    v_stride_n = v_cache.stride(2);
  }

  const c10::cuda::OptionalCUDAGuard device_guard(device);
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  DISPATCH_context(
      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
      AttentionVariant, PersistentParams, [&] {
        PersistentParams params[2];

        for (int i = 0; i < 2; i++) {
          params[i].q = static_cast<DTypeQ*>(q.data_ptr());
          params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
          params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr());

          params[i].q_indptr =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset);
          params[i].kv_indptr =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset);
          params[i].partial_indptr = GetPtrFromBaseOffset<IdType>(
              int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset);
          params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr());
          params[i].q_len =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset);
          params[i].kv_len =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_len_offset);
          params[i].q_start =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_start_offset);
          params[i].kv_start =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_start_offset);
          params[i].kv_end =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_end_offset);
          params[i].kv_head_idx_arr =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_head_idx_offset);
          params[i].work_indptr =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
          params[i].len_kv_chunk =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].len_kv_chunk_offset);

          params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
          params[i].final_lse =
              maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
          params[i].partial_o =
              GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
          params[i].partial_lse =
              GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

          // for state reduction
          params[i].merge_indptr =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
          params[i].merge_o_indices =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_o_indices_offset);
          params[i].num_packed_qo_len =
              GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.num_qo_len_offset);

          params[i].num_kv_heads = num_kv_heads;
          params[i].gqa_group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
          params[i].page_size = uint_fastdiv(page_size);

          params[i].q_stride_n = q_stride_n;
          params[i].q_stride_h = q_stride_h;
          params[i].k_stride_page = k_stride_page;
          params[i].k_stride_h = k_stride_h;
          params[i].k_stride_n = k_stride_n;
          params[i].v_stride_page = v_stride_page;
          params[i].v_stride_h = v_stride_h;
          params[i].v_stride_n = v_stride_n;

          params[i].sm_scale = sm_scale;

          ADDITIONAL_PARAMS_SETTER
        }

        // CTA_TILE_Q = 128 handles prefill-shaped work, CTA_TILE_Q = 16 handles decode-shaped work
        cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO,
                                                           MASK_MODE, AttentionVariant>(
            params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream);
        TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ",
                    cudaGetErrorString(status));
        return true;
      });
}
```
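The kNHD/kHND branch in `BatchPagedAttentionRun` only reads strides off the per-block K/V tensors. A small Python check of that bookkeeping is sketched below; the standalone contiguous tensors and the HND shape are assumptions made for illustration, with the NHD shape mirroring the benchmark's paged cache.

```python
import torch

# Per-block K (or V) cache shapes; NHD mirrors the benchmark's layout, HND is shown for contrast.
num_blocks, page_size, num_kv_heads, head_dim = 8, 16, 4, 128
k_nhd = torch.empty(num_blocks, page_size, num_kv_heads, head_dim)  # kNHD: (page, token, head, dim)
k_hnd = torch.empty(num_blocks, num_kv_heads, page_size, head_dim)  # kHND: (page, head, token, dim)

# kNHD branch: stride(1) advances over tokens, stride(2) over heads.
assert k_nhd.stride(1) == num_kv_heads * head_dim   # k_stride_n
assert k_nhd.stride(2) == head_dim                  # k_stride_h

# kHND branch: stride(1) advances over heads, stride(2) over tokens.
assert k_hnd.stride(1) == page_size * head_dim      # k_stride_h
assert k_hnd.stride(2) == head_dim                  # k_stride_n

# In both layouts, stride(0) is the per-page stride used as k_stride_page / v_stride_page.
assert k_nhd.stride(0) == k_hnd.stride(0) == page_size * num_kv_heads * head_dim
```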
