Merged
202 changes: 202 additions & 0 deletions benchmarks/bench_batch_attention.py
@@ -0,0 +1,202 @@
from __future__ import annotations

import itertools
from typing import List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from triton.testing import do_bench

import flashinfer


def run_bench(
kv_lens: Sequence[int],
qo_lens: Sequence[int],
*,
page_block_size: int,
num_kv_heads: int,
num_qo_heads: int,
head_dim: int,
device: int = 0,
causal: bool = True,
) -> Tuple[float, float, float, float, float]:
seq_lens = torch.tensor(kv_lens, dtype=torch.int32)
q_lens = torch.tensor(qo_lens, dtype=torch.int32)
seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
).int()
num_blocks = kv_indptr[-1].item()

q = torch.rand(
q_indptr[-1].item(), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
)
kv_data = torch.randn(
num_blocks,
2,
page_block_size,
num_kv_heads,
head_dim,
dtype=torch.bfloat16,
device=device,
)

    # old scheduler: BatchPrefillWithPagedKVCacheWrapper (fa2 backend)
wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device),
kv_layout="NHD",
backend="fa2",
)
last_page_len = (seq_lens - 1) % page_block_size + 1
wrapper_old.plan(
q_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_blocks, dtype=torch.int32, device=device),
last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_block_size,
causal=causal,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

    # new scheduler: BatchAttention (persistent kernel)
wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
q_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_blocks, dtype=torch.int32, device=device),
seq_lens.to(device),
num_qo_heads,
num_kv_heads,
        head_dim,  # qk head dim
        head_dim,  # vo head dim
page_block_size,
causal=causal,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
ms_new = do_bench(lambda: wrapper.run(q, kv_data))

total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
)
mem_MB = total_bytes / 1024**2
bw_old = total_bytes / (ms_old * 1e-3) / 1024**3
bw_new = total_bytes / (ms_new * 1e-3) / 1024**3

return ms_old, ms_new, mem_MB, bw_old, bw_new


def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]:
cfgs: List[List[Tuple[int, int]]] = [
[(8192, 1)] * 128, # decode-only
[(4096, 128)] * 4, # prefill-only
        [(600, 1)] * 122 + [(10_000, 17)] * 8,  # hybrid
[(8192, 1)] * 127 * 2 + [(2048, 512)] * 1, # hybrid (chunked-prefill)
]

def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]:
stride, sparsity = 16, 0.05
full = np.random.randint(lo, hi, size=bsz)
out = []
for i, kv_len in enumerate(full):
if i % stride == 0:
out.append((kv_len, stride + 1))
else:
out.append((int(kv_len * sparsity), 1))
return out

cfgs.append(_rand_case(256, 1000, 8192))
cfgs.append(_rand_case(128, 2000, 16_000))
return cfgs


def main() -> None:
np.random.seed(42)
torch.random.manual_seed(42)

seq_len_cfgs = synthesize_seq_len_configs()

sweep = {
"page_block_size": (1, 8, 16),
"head_dim": (64, 128),
"num_kv_heads": (4,),
"num_qo_heads": (28,),
}

records = []

for cfg_id, pairs in enumerate(seq_len_cfgs, start=1):
kv_lens = [p[0] for p in pairs]
qo_lens = [p[1] for p in pairs]
for pbs, hd, n_kv, n_qo in itertools.product(
sweep["page_block_size"],
sweep["head_dim"],
sweep["num_kv_heads"],
sweep["num_qo_heads"],
):

ms_old, ms_new, mem_MB, bw_old, bw_new = run_bench(
kv_lens,
qo_lens,
page_block_size=pbs,
num_kv_heads=n_kv,
num_qo_heads=n_qo,
head_dim=hd,
device=0,
causal=True,
)
records.extend(
[
{
"scheduler": "BatchPrefillWithPagedKVCacheWrapper",
"seq_cfg_id": cfg_id,
"page_size": pbs,
"head_dim": hd,
"num_kv_heads": n_kv,
"num_qo_heads": n_qo,
"time_ms": ms_old,
"memory_MB": mem_MB,
"bandwidth_GB_s": bw_old,
},
{
"scheduler": "BatchAttentionWrapper",
"seq_cfg_id": cfg_id,
"page_size": pbs,
"head_dim": hd,
"num_kv_heads": n_kv,
"num_qo_heads": n_qo,
"time_ms": ms_new,
"memory_MB": mem_MB,
"bandwidth_GB_s": bw_new,
},
]
)

df = pd.DataFrame(
records,
columns=[
"scheduler",
"seq_cfg_id",
"page_size",
"head_dim",
"num_kv_heads",
"num_qo_heads",
"time_ms",
"memory_MB",
"bandwidth_GB_s",
],
)
print(df.to_markdown(index=False, floatfmt=".2f"))


if __name__ == "__main__":
main()
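
For readers who want to exercise the new scheduler outside the sweep, here is a minimal single-request sketch that mirrors the calls made in `run_bench` above; the shapes, dtypes, and the `plan()`/`run()` signature are all taken from the benchmark, and the concrete sizes are illustrative only.

```python
import torch
import flashinfer

device = "cuda:0"
page_block_size, num_kv_heads, num_qo_heads, head_dim = 16, 4, 28, 128
kv_len, qo_len = 1024, 1  # a single decode-style request

num_blocks = (kv_len + page_block_size - 1) // page_block_size
q_indptr = torch.tensor([0, qo_len], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, num_blocks], dtype=torch.int32, device=device)
kv_indices = torch.arange(num_blocks, dtype=torch.int32, device=device)
seq_lens = torch.tensor([kv_len], dtype=torch.int32, device=device)

q = torch.rand(qo_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_data = torch.randn(
    num_blocks, 2, page_block_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device=device,
)

wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
    q_indptr, kv_indptr, kv_indices, seq_lens,
    num_qo_heads, num_kv_heads, head_dim, head_dim, page_block_size,
    causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
wrapper.run(q, kv_data)  # same call the benchmark times with do_bench
```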
184 changes: 184 additions & 0 deletions csrc/batch_attention.cu
@@ -0,0 +1,184 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/pos_enc.cuh>
#include <optional>

#include "batch_attention_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
const uint32_t num_blks_x, const uint32_t num_blks_y,
const cudaStream_t stream);
} // namespace flashinfer

using namespace flashinfer;

at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_o, bool causal) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

HolisticPlanInfo<2> plan_info;

const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

cudaError_t status = TwoStageHolisticPlan<IdType>(
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), kv_len.data_ptr<IdType>(), batch_size, num_qo_heads,
num_kv_heads, head_dim_o, causal, stream);

TORCH_CHECK(status == cudaSuccess,
"Failed to plan persistent paged attention, error: ", cudaGetErrorString(status));

return vec_to_tensor(plan_info.ToVector());
}

void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache,
at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS) {
HolisticPlanInfo<2> plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));

auto device = q.device();

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
void* int_buffer_ptr = int_workspace_buffer.data_ptr();

const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = k_cache.scalar_type();

// NOTE (Yilong): assume both q and o are NHD
unsigned int q_stride_n = q.stride(0);
unsigned int q_stride_h = q.stride(1);

  // the layout code only constrains the paged KV cache; q and o stay NHD
const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code);
unsigned int k_stride_page = k_cache.stride(0);
unsigned int v_stride_page = v_cache.stride(0);
unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h;
if (kv_layout == QKVLayout::kNHD) {
k_stride_h = k_cache.stride(2);
k_stride_n = k_cache.stride(1);
v_stride_h = v_cache.stride(2);
v_stride_n = v_cache.stride(1);
} else {
k_stride_h = k_cache.stride(1);
k_stride_n = k_cache.stride(2);
v_stride_h = v_cache.stride(1);
v_stride_n = v_cache.stride(2);
}

const c10::cuda::OptionalCUDAGuard device_guard(device);
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
AttentionVariant, PersistentParams, [&] {
PersistentParams params[2];

for (int i = 0; i < 2; i++) {
params[i].q = static_cast<DTypeQ*>(q.data_ptr());
params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr());

params[i].q_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset);
params[i].kv_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset);
params[i].partial_indptr = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset);
params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr());
params[i].q_len =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset);
params[i].kv_len =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_len_offset);
params[i].q_start =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_start_offset);
params[i].kv_start =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_start_offset);
params[i].kv_end =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_end_offset);
params[i].kv_head_idx_arr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_head_idx_offset);
params[i].work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
params[i].len_kv_chunk =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].len_kv_chunk_offset);

params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
params[i].final_lse =
maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params[i].partial_o =
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
params[i].partial_lse =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

// for state reduction
params[i].merge_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
params[i].merge_o_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_o_indices_offset);
params[i].num_packed_qo_len =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.num_qo_len_offset);

params[i].num_kv_heads = num_kv_heads;
params[i].gqa_group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
params[i].page_size = uint_fastdiv(page_size);

params[i].q_stride_n = q_stride_n;
params[i].q_stride_h = q_stride_h;
params[i].k_stride_page = k_stride_page;
params[i].k_stride_h = k_stride_h;
params[i].k_stride_n = k_stride_n;
params[i].v_stride_page = v_stride_page;
params[i].v_stride_h = v_stride_h;
params[i].v_stride_n = v_stride_n;

params[i].sm_scale = sm_scale;

ADDITIONAL_PARAMS_SETTER
}

cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO,
MASK_MODE, AttentionVariant>(
params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream);
TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ",
cudaGetErrorString(status));
return true;
});
}
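
A small aside on the stride selection in `BatchPagedAttentionRun`: for `kNHD` each page of the cache is laid out as `[page_size, num_heads, head_dim]`, so `stride(1)` advances one token within a page and `stride(2)` advances one head; `kHND` swaps the two. The snippet below is a host-side sanity check using a simplified contiguous 4-D cache (the benchmark's combined 5-D `kv_data` is presumably split into K/V views before reaching this function).

```python
import torch

num_pages, page_size, num_heads, head_dim = 8, 16, 4, 128

# kNHD: per-page layout is [page_size, num_heads, head_dim]
k_nhd = torch.empty(num_pages, page_size, num_heads, head_dim)
# stride(0) = page stride, stride(1) = k_stride_n (token), stride(2) = k_stride_h (head)
print(k_nhd.stride(0), k_nhd.stride(1), k_nhd.stride(2))

# kHND: per-page layout is [num_heads, page_size, head_dim]
k_hnd = torch.empty(num_pages, num_heads, page_size, head_dim)
# stride(0) = page stride, stride(1) = k_stride_h (head), stride(2) = k_stride_n (token)
print(k_hnd.stride(0), k_hnd.stride(1), k_hnd.stride(2))
```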