From 156eac8c39dc66a79f4d99be990799d1c16a0642 Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 11 Jun 2025 17:56:15 +0000 Subject: [PATCH] [feat] add unified batch attention w/ correctness tests. --- benchmarks/bench_batch_attention.py | 202 +++++++ csrc/batch_attention.cu | 184 ++++++ csrc/batch_attention_customize_config.jinja | 96 +++ csrc/batch_attention_jit_pybind.cu | 36 ++ csrc/batch_attention_paged_kernel_inst.jinja | 9 + flashinfer/__init__.py | 1 + flashinfer/attention.py | 177 ++++++ flashinfer/jit/__init__.py | 2 + flashinfer/jit/attention/__init__.py | 2 + flashinfer/jit/attention/pytorch.py | 193 ++++++ include/flashinfer/attention/cascade.cuh | 5 +- include/flashinfer/attention/persistent.cuh | 566 ++++++++++++++++++ .../attention/persistent_template.cuh | 91 +++ include/flashinfer/attention/prefill.cuh | 39 +- include/flashinfer/attention/scheduler.cuh | 343 +++++++++++ include/flashinfer/permuted_smem.cuh | 10 + include/flashinfer/utils.cuh | 12 + tests/test_batch_attention.py | 171 ++++++ 18 files changed, 2119 insertions(+), 20 deletions(-) create mode 100644 benchmarks/bench_batch_attention.py create mode 100644 csrc/batch_attention.cu create mode 100644 csrc/batch_attention_customize_config.jinja create mode 100644 csrc/batch_attention_jit_pybind.cu create mode 100644 csrc/batch_attention_paged_kernel_inst.jinja create mode 100644 flashinfer/attention.py create mode 100644 include/flashinfer/attention/persistent.cuh create mode 100644 include/flashinfer/attention/persistent_template.cuh create mode 100644 tests/test_batch_attention.py diff --git a/benchmarks/bench_batch_attention.py b/benchmarks/bench_batch_attention.py new file mode 100644 index 0000000000..5274688f98 --- /dev/null +++ b/benchmarks/bench_batch_attention.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import itertools +from typing import List, Sequence, Tuple + +import numpy as np +import pandas as pd +import torch +from triton.testing import do_bench + +import flashinfer + + +def run_bench( + kv_lens: Sequence[int], + qo_lens: Sequence[int], + *, + page_block_size: int, + num_kv_heads: int, + num_qo_heads: int, + head_dim: int, + device: int = 0, + causal: bool = True, +) -> Tuple[float, float, float, float, float]: + seq_lens = torch.tensor(kv_lens, dtype=torch.int32) + q_lens = torch.tensor(qo_lens, dtype=torch.int32) + seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() + + q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() + kv_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 + ).int() + num_blocks = kv_indptr[-1].item() + + q = torch.rand( + q_indptr[-1].item(), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device + ) + kv_data = torch.randn( + num_blocks, + 2, + page_block_size, + num_kv_heads, + head_dim, + dtype=torch.bfloat16, + device=device, + ) + + # old + wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + kv_layout="NHD", + backend="fa2", + ) + last_page_len = (seq_lens - 1) % page_block_size + 1 + wrapper_old.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + ms_old = do_bench(lambda: wrapper_old.run(q, kv_data)) + + # new + wrapper = flashinfer.BatchAttention(kv_layout="NHD") + 
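For readers skimming the diff, the paging metadata built above follows the usual FlashInfer convention: each request's KV length is rounded up to whole pages for `kv_indptr`, and the legacy prefill wrapper additionally needs the occupancy of the last page, while the new `BatchAttention.plan` below takes the raw per-request lengths instead. A minimal sketch of that arithmetic, using only quantities already defined in this benchmark (the concrete numbers are illustrative):

```python
# Sketch of the paging arithmetic used in this benchmark (values are illustrative).
import torch

page_block_size = 16
seq_lens = torch.tensor([8192, 600, 10_000], dtype=torch.int32)

# Number of pages per request (ceil division), then a cumulative offset array.
seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)]).int()

# The legacy BatchPrefillWithPagedKVCacheWrapper also needs how many entries the
# last page holds; BatchAttention.plan below consumes seq_lens directly.
last_page_len = (seq_lens - 1) % page_block_size + 1
# e.g. kv_len=600, page_block_size=16 -> 38 pages, last page holds 8 entries.
```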
wrapper.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + ms_new = do_bench(lambda: wrapper.run(q, kv_data)) + + total_bytes = ( + q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() + ) + mem_MB = total_bytes / 1024**2 + bw_old = total_bytes / (ms_old * 1e-3) / 1024**3 + bw_new = total_bytes / (ms_new * 1e-3) / 1024**3 + + return ms_old, ms_new, mem_MB, bw_old, bw_new + + +def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]: + cfgs: List[List[Tuple[int, int]]] = [ + [(8192, 1)] * 128, # decode-only + [(4096, 128)] * 4, # prefill-only + [(600, 1)] * 122 + [(10_000, 17)] * 8, # hybird + [(8192, 1)] * 127 * 2 + [(2048, 512)] * 1, # hybrid (chunked-prefill) + ] + + def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]: + stride, sparsity = 16, 0.05 + full = np.random.randint(lo, hi, size=bsz) + out = [] + for i, kv_len in enumerate(full): + if i % stride == 0: + out.append((kv_len, stride + 1)) + else: + out.append((int(kv_len * sparsity), 1)) + return out + + cfgs.append(_rand_case(256, 1000, 8192)) + cfgs.append(_rand_case(128, 2000, 16_000)) + return cfgs + + +def main() -> None: + np.random.seed(42) + torch.random.manual_seed(42) + + seq_len_cfgs = synthesize_seq_len_configs() + + sweep = { + "page_block_size": (1, 8, 16), + "head_dim": (64, 128), + "num_kv_heads": (4,), + "num_qo_heads": (28,), + } + + records = [] + + for cfg_id, pairs in enumerate(seq_len_cfgs, start=1): + kv_lens = [p[0] for p in pairs] + qo_lens = [p[1] for p in pairs] + for pbs, hd, n_kv, n_qo in itertools.product( + sweep["page_block_size"], + sweep["head_dim"], + sweep["num_kv_heads"], + sweep["num_qo_heads"], + ): + + ms_old, ms_new, mem_MB, bw_old, bw_new = run_bench( + kv_lens, + qo_lens, + page_block_size=pbs, + num_kv_heads=n_kv, + num_qo_heads=n_qo, + head_dim=hd, + device=0, + causal=True, + ) + records.extend( + [ + { + "scheduler": "BatchPrefillWithPagedKVCacheWrapper", + "seq_cfg_id": cfg_id, + "page_size": pbs, + "head_dim": hd, + "num_kv_heads": n_kv, + "num_qo_heads": n_qo, + "time_ms": ms_old, + "memory_MB": mem_MB, + "bandwidth_GB_s": bw_old, + }, + { + "scheduler": "BatchAttentionWrapper", + "seq_cfg_id": cfg_id, + "page_size": pbs, + "head_dim": hd, + "num_kv_heads": n_kv, + "num_qo_heads": n_qo, + "time_ms": ms_new, + "memory_MB": mem_MB, + "bandwidth_GB_s": bw_new, + }, + ] + ) + + df = pd.DataFrame( + records, + columns=[ + "scheduler", + "seq_cfg_id", + "page_size", + "head_dim", + "num_kv_heads", + "num_qo_heads", + "time_ms", + "memory_MB", + "bandwidth_GB_s", + ], + ) + print(df.to_markdown(index=False, floatfmt=".2f")) + + +if __name__ == "__main__": + main() diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu new file mode 100644 index 0000000000..263c108304 --- /dev/null +++ b/csrc/batch_attention.cu @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include "batch_attention_config.inc" +#include "pytorch_conversion_utils.h" +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2, + const uint32_t num_blks_x, const uint32_t num_blks_y, + const cudaStream_t stream); +} // namespace flashinfer + +using namespace flashinfer; + +at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_o, bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + HolisticPlanInfo<2> plan_info; + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + + cudaError_t status = TwoStageHolisticPlan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), kv_len.data_ptr(), batch_size, num_qo_heads, + num_kv_heads, head_dim_o, causal, stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan persistent paged attention, error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + +void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache, + at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS) { + HolisticPlanInfo<2> plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + + auto device = q.device(); + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + const MaskMode mask_mode = static_cast(mask_mode_code); + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k_cache.scalar_type(); + + // NOTE (Yilong): assume both q and o are NHD + unsigned int q_stride_n = q.stride(0); + unsigned int q_stride_h = q.stride(1); + + // layout only constraint paged KV + const QKVLayout kv_layout = static_cast(layout_code); + unsigned int k_stride_page = k_cache.stride(0); + unsigned int v_stride_page = v_cache.stride(0); + unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h; + if (kv_layout == QKVLayout::kNHD) { + k_stride_h = k_cache.stride(2); + k_stride_n = k_cache.stride(1); + v_stride_h = v_cache.stride(2); + v_stride_n = v_cache.stride(1); + } else { + k_stride_h = k_cache.stride(1); + 
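The stride bookkeeping above and below mirrors the two paged-KV layouts handled here. A small PyTorch sketch of where each stride comes from for a K (or V) cache page tensor; the shapes are illustrative and not part of this patch:

```python
# Illustrative shapes: 32 pages, page_size=16, 4 KV heads, head_dim=128.
import torch

num_pages, page_size, num_kv_heads, head_dim = 32, 16, 4, 128

# NHD layout: [num_pages, page_size, num_heads, head_dim]
k_nhd = torch.empty(num_pages, page_size, num_kv_heads, head_dim)
k_stride_page, k_stride_n, k_stride_h = k_nhd.stride(0), k_nhd.stride(1), k_nhd.stride(2)

# HND layout: [num_pages, num_heads, page_size, head_dim]
k_hnd = torch.empty(num_pages, num_kv_heads, page_size, head_dim)
k_stride_page, k_stride_h, k_stride_n = k_hnd.stride(0), k_hnd.stride(1), k_hnd.stride(2)
```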
k_stride_n = k_cache.stride(2); + v_stride_h = v_cache.stride(1); + v_stride_n = v_cache.stride(2); + } + + const c10::cuda::OptionalCUDAGuard device_guard(device); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + AttentionVariant, PersistentParams, [&] { + PersistentParams params[2]; + + for (int i = 0; i < 2; i++) { + params[i].q = static_cast(q.data_ptr()); + params[i].k = static_cast(k_cache.data_ptr()); + params[i].v = static_cast(v_cache.data_ptr()); + + params[i].q_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset); + params[i].kv_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset); + params[i].partial_indptr = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset); + params[i].kv_indices = static_cast(kv_indices.data_ptr()); + params[i].q_len = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].q_len_offset); + params[i].kv_len = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_len_offset); + params[i].q_start = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].q_start_offset); + params[i].kv_start = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_start_offset); + params[i].kv_end = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_end_offset); + params[i].kv_head_idx_arr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_head_idx_offset); + params[i].work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset); + params[i].len_kv_chunk = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].len_kv_chunk_offset); + + params[i].final_o = static_cast(o.data_ptr()); + params[i].final_lse = + maybe_lse.has_value() ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params[i].partial_o = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); + params[i].partial_lse = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_lse_offset); + + // for state reduction + params[i].merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + params[i].merge_o_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_o_indices_offset); + params[i].num_packed_qo_len = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.num_qo_len_offset); + + params[i].num_kv_heads = num_kv_heads; + params[i].gqa_group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params[i].page_size = uint_fastdiv(page_size); + + params[i].q_stride_n = q_stride_n; + params[i].q_stride_h = q_stride_h; + params[i].k_stride_page = k_stride_page; + params[i].k_stride_h = k_stride_h; + params[i].k_stride_n = k_stride_n; + params[i].v_stride_page = v_stride_page; + params[i].v_stride_h = v_stride_h; + params[i].v_stride_n = v_stride_n; + + params[i].sm_scale = sm_scale; + + ADDITIONAL_PARAMS_SETTER + } + + cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO, + MASK_MODE, AttentionVariant>( + params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream); + TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ", + cudaGetErrorString(status)); + return true; + }); +} diff --git a/csrc/batch_attention_customize_config.jinja b/csrc/batch_attention_customize_config.jinja new file mode 100644 index 0000000000..7a2875494a --- /dev/null +++ b/csrc/batch_attention_customize_config.jinja @@ -0,0 +1,96 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace flashinfer; + +#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} +#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} + +{{ variant_decl }} + +struct StandardAttention : AttentionVariantBase { + float sm_scale_log2; + + PROFILER_CLOSURE_PARAMS_DECL + + template + __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + sm_scale_log2 = params.sm_scale * math::log2e; + } +}; + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) 
\ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + using AttentionVariant = {{ variant_name }}; \ + __VA_ARGS__(); \ + }) + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ idtype }}; + +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; +constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; + +struct PersistentParams { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + DTypeO* o; + DTypeO* partial_o; + float* partial_lse; + DTypeO* final_o; + float* final_lse; + + IdType* q_indptr; + IdType* kv_indptr; + IdType* partial_indptr; + IdType* kv_indices; + IdType* q_len; + IdType* kv_len; + IdType* q_start; + IdType* kv_start; + IdType* kv_end; + IdType* kv_head_idx_arr; + IdType* work_indptr; + IdType* len_kv_chunk; + + // for state reduction + IdType* merge_indptr; + IdType* merge_o_indices; + IdType* num_packed_qo_len; + + uint32_t num_kv_heads; + uint_fastdiv gqa_group_size; + uint_fastdiv page_size; + + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_page; + uint32_t k_stride_h; + uint32_t k_stride_n; + uint32_t v_stride_page; + uint32_t v_stride_h; + uint32_t v_stride_n; + + float sm_scale; + + {{ additional_params_decl }} + + PROFILER_PARAMS_DECL +}; diff --git a/csrc/batch_attention_jit_pybind.cu b/csrc/batch_attention_jit_pybind.cu new file mode 100644 index 0000000000..562abc5a58 --- /dev/null +++ b/csrc/batch_attention_jit_pybind.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "batch_attention_config.inc" +#include "pytorch_extension_utils.h" + +at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_o, bool causal); + +void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache, + at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS); + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("plan", &BatchPagedAttentionPlan); + m.def("run", &BatchPagedAttentionRun); +} diff --git a/csrc/batch_attention_paged_kernel_inst.jinja b/csrc/batch_attention_paged_kernel_inst.jinja new file mode 100644 index 0000000000..b6a915feca --- /dev/null +++ b/csrc/batch_attention_paged_kernel_inst.jinja @@ -0,0 +1,9 @@ +#include +#include "batch_attention_config.inc" + +namespace flashinfer { +template cudaError_t BatchPagedAttentionPersistent< + /*CTA_TILE_Q_1=*/128, /*CTA_TILE_Q_2=*/16, {{head_dim_qk}}, {{head_dim_vo}}, {{mask_mode}}, + {{ variant_name }}, PersistentParams>(const PersistentParams params_1, const PersistentParams params_2, + const uint32_t num_blks_x, const uint32_t num_blks_y, const cudaStream_t stream); +}; // namespace flashinfer diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 2befd347ac..b92d32b89e 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -24,6 +24,7 @@ from .activation import gelu_and_mul as gelu_and_mul from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul from .activation import silu_and_mul as silu_and_mul +from .attention import BatchAttention as BatchAttention from .cascade import ( BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper, ) diff --git a/flashinfer/attention.py b/flashinfer/attention.py new file mode 100644 index 0000000000..fa5fb56dbc --- /dev/null +++ b/flashinfer/attention.py @@ -0,0 +1,177 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import functools +import logging +import math +from types import SimpleNamespace +from typing import Any, List, Literal, Optional, Tuple, Union, overload + +import torch + +from .jit import gen_batch_attention_module +from .utils import ( + MaskMode, + PosEncodingMode, + TensorLayout, + _check_kv_layout, + _unpack_paged_kv_cache, +) + +_batch_attention_modules = {} + + +def get_holistic_attention_module(): + def backend_module(*args): + global _batch_attention_modules + modules_dict = _batch_attention_modules + if args not in modules_dict: + module = gen_batch_attention_module(*args).build_and_load() + return module + + return backend_module + + +class BatchAttention: + def __init__( + self, + kv_layout: str = "NHD", + ): + _check_kv_layout(kv_layout) + self._kv_layout = kv_layout + + self.float_workspace_buffer = torch.empty( + 256 * 1024 * 1024, + dtype=torch.uint8, + device=torch.device("cuda"), + ) + self.int_workspace_buffer = torch.empty( + 8 * 1024 * 1024, + dtype=torch.uint8, + device=torch.device("cuda"), + ) + self.page_locked_int_workspace_buffer = torch.empty( + 8 * 1024 * 1024, + dtype=torch.uint8, + device=torch.device("cpu"), + pin_memory=True, + ) + + def plan( + self, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + kv_len_arr: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim_qk: int, + head_dim_vo: int, + page_size: int, + causal: bool = False, + sm_scale: float = None, + q_data_type: torch.dtype = torch.bfloat16, + kv_data_type: torch.dtype = torch.bfloat16, + use_profiler: bool = False, + ) -> None: + # get jit module + get_module_args = ( + q_data_type, + kv_data_type, + q_data_type, + kv_indptr.dtype, + head_dim_qk, + head_dim_vo, + PosEncodingMode["NONE"].value, + ) + self.module = get_holistic_attention_module()(*get_module_args) + + qo_indptr_host = qo_indptr.to(torch.device("cpu"), non_blocking=True) + kv_indptr_host = kv_indptr.to(torch.device("cpu"), non_blocking=True) + kv_len_arr_host = kv_len_arr.to(torch.device("cpu"), non_blocking=True) + torch.cuda.synchronize() + + batch_size = kv_len_arr.shape[0] + self._page_size = page_size + self._sm_scale = sm_scale + self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value + self._num_qo_heads = num_qo_heads + self._num_kv_heads = num_kv_heads + self._page_size = page_size + self._sm_scale = sm_scale + self._use_profiler = use_profiler + + # No addtional buf allocated for CUDA graph tensor + # Allocate outside FlashInfer + self._kv_indices = kv_indices + + self._plan_info = self.module.plan( + self.float_workspace_buffer, + self.int_workspace_buffer, + self.page_locked_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + kv_len_arr_host, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim_vo, + causal, + ) + + def run( + self, + q: torch.Tensor, + kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + profiler_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if profiler_buffer is None: + if self._use_profiler: + raise ValueError( + "Profiler is enabled, profiler_buffer must be provided" + ) + k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout) + if out is None: + out = torch.empty_like(q) + if lse is None: + # lse shape: [batch_size, num_qo_heads] + lse = torch.empty(q.shape[0], q.shape[1], device=q.device) + + head_dim_qk = q.shape[2] + if self._sm_scale is None: + self._sm_scale = 
1.0 / math.sqrt(head_dim_qk) + + self.module.run( + self.float_workspace_buffer, + self.int_workspace_buffer, + self._plan_info, + q, + k_cache, + v_cache, + self._kv_indices, + out, + lse, + self._mask_mode, + TensorLayout[self._kv_layout].value, + self._num_qo_heads, + self._num_kv_heads, + self._page_size, + self._sm_scale, + ) + + return out, lse diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 8db1f5a37c..c2c82dd9b8 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -22,6 +22,7 @@ from . import env as env from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str +from .attention import gen_batch_attention_module as gen_batch_attention_module from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .attention import gen_batch_decode_module as gen_batch_decode_module from .attention import gen_batch_mla_module as gen_batch_mla_module @@ -50,6 +51,7 @@ from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding from .attention import gen_single_decode_module as gen_single_decode_module from .attention import gen_single_prefill_module as gen_single_prefill_module +from .attention import get_batch_attention_uri as get_batch_attention_uri from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .attention import get_batch_decode_uri as get_batch_decode_uri from .attention import get_batch_mla_uri as get_batch_mla_uri diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index 7ef679e6b7..b44ed3e076 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -15,6 +15,7 @@ """ from . 
import pytorch, tvm +from .pytorch import gen_batch_attention_module as gen_batch_attention_module from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .pytorch import gen_batch_decode_module as gen_batch_decode_module from .pytorch import gen_batch_mla_module as gen_batch_mla_module @@ -35,6 +36,7 @@ from .pytorch import gen_pod_module as gen_pod_module from .pytorch import gen_single_decode_module as gen_single_decode_module from .pytorch import gen_single_prefill_module as gen_single_prefill_module +from .pytorch import get_batch_attention_uri as get_batch_attention_uri from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .pytorch import get_batch_decode_uri as get_batch_decode_uri from .pytorch import get_batch_mla_uri as get_batch_mla_uri diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index cacc408151..5ca39220c9 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -387,6 +387,26 @@ def get_batch_prefill_uri( ) +def get_batch_attention_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, +) -> str: + return ( + f"batch_attention_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" + f"posenc_{pos_encoding_mode}" + ) + + def gen_single_decode_module( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -831,6 +851,92 @@ def gen_batch_prefill_module( ) +def gen_batch_attention_module( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, +): + uri = get_batch_attention_uri( + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + ) + additional_tensor_names = [] + additional_tensor_dtypes = [] + additional_scalar_names = [] + additional_scalar_dtypes = [] + variant_name = f"StandardAttention" + variant_decl = f"#include" + + return gen_customize_batch_attention_module( + uri, + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_qk, + head_dim_vo, + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + variant_name, + variant_decl, + pos_encoding_mode=pos_encoding_mode, + ) + + +def gen_batch_attention_module( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, +): + uri = get_batch_attention_uri( + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + ) + additional_tensor_names = [] + additional_tensor_dtypes = [] + additional_scalar_names = [] + additional_scalar_dtypes = [] + variant_name = f"StandardAttention" + variant_decl = f"#include" + + return gen_customize_batch_attention_module( + uri, + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_qk, + head_dim_vo, + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + variant_name, + variant_decl, + pos_encoding_mode=pos_encoding_mode, + ) + + def gen_customize_single_decode_module( uri: str, dtype_q: torch.dtype, @@ -1421,3 +1527,90 @@ def 
trtllm_fmha_gen_module(): ], extra_ldflags=["-lcuda"], ) + + +def gen_customize_batch_attention_module( + uri: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + idtype: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + additional_tensor_names: List[str], + additional_tensor_dtypes: List[str], + additional_scalar_names: List[str], + additional_scalar_dtypes: List[str], + variant_name: str, + variant_decl: str, + pos_encoding_mode: int = 0, +): + kwargs = { + "variant_decl": variant_decl, + "variant_name": variant_name, + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "idtype": dtype_map[idtype], + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + } + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + (additional_params_decl, additional_func_params, additional_params_setter) = ( + generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + ) + + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_attention_customize_config.jinja" + ) as f: + config_templ = jinja2.Template(f.read()) + + with open( + jit_env.FLASHINFER_CSRC_DIR / "batch_attention_paged_kernel_inst.jinja" + ) as f: + paged_kernel_inst_templ = jinja2.Template(f.read()) + + kwargs |= { + "additional_params_decl": additional_params_decl, + "additional_func_params": additional_func_params, + "additional_params_setter": additional_params_setter, + } + + generated_inc_str = config_templ.render( + **kwargs, + ) + os.makedirs(gen_directory, exist_ok=True) + + source_paths = [] + for mask_mode in [0, 1, 2, 3]: + dest_path = gen_directory / f"batch_attention_paged_kernel_mask_{mask_mode}.cu" + source_paths.append(dest_path) + source = paged_kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], + **kwargs, + ) + write_if_different(dest_path, source) + + for filename in [ + "batch_attention.cu", + "batch_attention_jit_pybind.cu", + ]: + src_path = jit_env.FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_attention_config.inc" + write_if_different(generated_config_path, generated_inc_str) + return gen_jit_spec( + uri, + source_paths, + ) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 1c72997081..0ff1429794 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -117,8 +117,9 @@ __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict template __device__ __forceinline__ void threadblock_sync_state(state_t& st, DTypeIn* v_smem, - float* s_smem) { - const uint32_t tx = threadIdx.x, ty = threadIdx.y; + float* s_smem, + const uint32_t tx = threadIdx.x, + const uint32_t ty = threadIdx.y) { constexpr uint32_t head_dim = vec_size * bdx; st.o.cast_store(v_smem + ty * head_dim + tx * vec_size); s_smem[ty] = st.get_lse(); diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh new file mode 100644 index 0000000000..95ecbabd69 --- /dev/null +++ b/include/flashinfer/attention/persistent.cuh @@ -0,0 +1,566 @@ +/* + * Copyright (c) 2025 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_PERSISTENT_CUH_ +#define FLASHINFER_PERSISTENT_CUH_ + +#include "../cp_async.cuh" +#include "../math.cuh" +#include "../utils.cuh" +#include "mask.cuh" +#include "persistent_template.cuh" +#include "prefill.cuh" +#include "state.cuh" + +namespace flashinfer { + +using cp_async::PrefetchMode; +using cp_async::SharedMemFillMode; + +template +__device__ __forceinline__ auto get_block_coord(const Params& params, const uint32_t work_idx) { + return std::tuple(params.q_indptr[work_idx], params.kv_indptr[work_idx], + params.partial_indptr[work_idx], params.q_len[work_idx], + params.kv_len[work_idx], params.q_start[work_idx], params.kv_start[work_idx], + params.kv_end[work_idx], params.kv_head_idx_arr[work_idx], + params.len_kv_chunk[work_idx]); +} + +template +__device__ __forceinline__ void prefetch_offest( + const uint32_t packed_block_iter_base, const uint32_t packed_kv_bound, + const uint32_t kv_head_idx, const uint32_t kv_stride_page, const uint32_t kv_stride_h, + const uint32_t kv_stride_n, const uint_fastdiv& block_size, typename KTraits::IdType* indices, + size_t* kv_offset) { + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + const uint32_t lane_idx = threadIdx.x % 32, warp_idx = threadIdx.x / 32; + +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { + uint32_t page_iter, entry_idx; + uint32_t packed_block_iter = packed_block_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i; + block_size.divmod(packed_block_iter, page_iter, entry_idx); + kv_offset[i] = (packed_block_iter < packed_kv_bound ? 
indices[page_iter] : 0) * kv_stride_page + + entry_idx * kv_stride_n + kv_head_idx * kv_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + } +} + +template +__device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], + smem_t* o_smem, + typename KTraits::DTypeO* o_ptr_base, + const uint32_t o_packed_idx_base_warp, + const uint32_t o_packed_idx_base_cta, + const uint32_t qo_upper_bound, const uint32_t o_stride_n, + const uint_fastdiv group_size, const uint32_t warp_idx, + const uint32_t lane_idx, const dim3 tid) { + using DTypeO = typename KTraits::DTypeO; + constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + const uint32_t warp_idx_x = get_warp_idx_q(tid.y), + warp_idx_z = get_warp_idx_kv(tid.z); + + static_assert(sizeof(DTypeO) == 2); + if (warp_idx_z == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t o_frag_f16[8 / 2]; + vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]); + +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, + mma_d * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[3]; +#endif + } + } + + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + const uint32_t o_packed_idx = o_packed_idx_base_warp + lane_idx / 8 + mma_q * 16 + j * 4; + group_size.divmod(o_packed_idx, q, r); + + const uint32_t o_idx = q; + DTypeO* o_ptr = o_ptr_base + (o_packed_idx - o_packed_idx_base_cta) * o_stride_n + + (lane_idx % 8) * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) { + if (o_idx < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * upcast_size(); + o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do); + } + o_smem_offset_w = + o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) - + 2 * KTraits::NUM_MMA_D_VO; + } + } + } +} + +template +struct BlockBatchPagedAttentionPersistent { + using KTraits = KTraits_; + using Params = Params_; + + static __device__ __forceinline__ void Run(const Params& params, + typename KTraits::SharedStorage* smem_storage) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + 
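The planner flattens each CTA's assignment into parallel arrays indexed by a global `work_idx`, and `get_block_coord` above simply gathers one row of that table. A hedged, host-side Python sketch of how such a work table is consumed (field names follow the `PersistentParams` struct in this patch; this is an illustration of the data flow, not the scheduler itself):

```python
# Illustrative: how one persistent CTA walks its slice of the planner's work table.
# `work_indptr` partitions work items across CTAs; every other array is indexed by work_idx.
def iterate_cta_work(cta_id, params):
    for work_idx in range(params["work_indptr"][cta_id], params["work_indptr"][cta_id + 1]):
        yield {
            "q_indptr": params["q_indptr"][work_idx],        # where this request's Q rows start
            "kv_indptr": params["kv_indptr"][work_idx],      # where its KV pages start
            "partial_indptr": params["partial_indptr"][work_idx],  # slot in the partial-output buffer
            "q_len": params["q_len"][work_idx],
            "kv_len": params["kv_len"][work_idx],
            "q_start": params["q_start"][work_idx],          # tile start in packed (query x head) coords
            "kv_start": params["kv_start"][work_idx],
            "kv_end": params["kv_end"][work_idx],
            "kv_head_idx": params["kv_head_idx_arr"][work_idx],
            "len_kv_chunk": params["len_kv_chunk"][work_idx],
        }
```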
[[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr bool CAUSAL = KTraits::MASK_MODE == MaskMode::kCausal; + [[maybe_unused]] constexpr uint32_t NUM_STAGES = KTraits::NUM_STAGES; + + DTypeQ* q = params.q; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + IdType* kv_indices = params.kv_indices; + DTypeO* partial_o = params.partial_o; + float* partial_lse = params.partial_lse; + IdType* work_indptr = params.work_indptr; + + float s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + float m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + + const uint_fastdiv& gqa_group_size = params.gqa_group_size; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint_fastdiv& block_size = params.page_size; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_page = params.k_stride_page; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t v_stride_page = params.v_stride_page; + const uint32_t v_stride_h = params.v_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t cluster_tile_q = gridDim.x * CTA_TILE_Q; + smem_t q_smem(smem_storage->q_smem); + + AttentionVariant variant(params, /*batch_idx=*/0, nullptr); + + const uint32_t lane_idx = threadIdx.x % 32; + const uint32_t warp_idx = threadIdx.x / 32; + + // threadIdx: [32, NUM_WARPS_Q, NUM_WARPS_KV] + // remap to utilize tool function in FA2 prefill + const dim3 tid = dim3(lane_idx, warp_idx % NUM_WARPS_Q, warp_idx / NUM_WARPS_Q); + + uint32_t q_smem_offset_r = get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + uint32_t k_smem_offset_r = get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); + uint32_t k_smem_offset_w = get_permuted_offset( + warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL, + lane_idx % KTraits::KV_THR_LAYOUT_COL), + v_smem_offset_w = get_permuted_offset( + warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL, + lane_idx % KTraits::KV_THR_LAYOUT_COL); + size_t thr_local_kv_offset[NUM_MMA_KV * KTraits::KV_THR_LAYOUT_COL / 2 / KTraits::NUM_WARPS_Q]; + +#pragma unroll 1 + for (IdType work_idx = 
work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1]; + ++work_idx) { + const auto [q_indptr, kv_indptr, o_indptr, q_len, kv_len, packed_qo_start, kv_start, kv_end, + kv_head_idx, len_kv_chunk] = get_block_coord(params, work_idx); + + const uint32_t kv_chunk_idx = ceil_div(kv_start, len_kv_chunk); + const uint32_t num_kv_chunks = ceil_div( + CAUSAL + ? min((kv_len - q_len) + (packed_qo_start + cluster_tile_q) / gqa_group_size, kv_len) + : kv_len, + len_kv_chunk); + + const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * CTA_TILE_Q + + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16; + const uint32_t qo_upperbound = + min(q_len, ceil_div(qo_packed_idx_base + CTA_TILE_Q, gqa_group_size)); + + init_states(variant, o_frag, m, d); + + DTypeQ* q_ptr_base = q + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h; + + // load_q + load_q_global_smem(qo_packed_idx_base, qo_upperbound, q_ptr_base, q_stride_n, + q_stride_h, gqa_group_size, &q_smem, tid); + + smem_t k_smem(smem_storage->k_smem), v_smem(smem_storage->v_smem); + int kv_tile_idx = + ceil_div((CAUSAL ? min(kv_end, kv_len - q_len + + (packed_qo_start + cluster_tile_q) / gqa_group_size) + : kv_end), + CTA_TILE_KV) - + 1 - (kv_start / CTA_TILE_KV); + + int mask_tile_idx = + (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / gqa_group_size) : kv_end) / + CTA_TILE_KV - + (kv_start / CTA_TILE_KV); + + uint32_t block_iter_base = kv_indptr * block_size + kv_start; + // last kv tile + __syncthreads(); + uint32_t packed_kv_bound = kv_indptr * block_size + kv_len; + + prefetch_offest(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound, + kv_head_idx, k_stride_page, k_stride_h, k_stride_n, block_size, + kv_indices, thr_local_kv_offset); + page_produce_kv(smem_storage, &k_smem_offset_w, k, + kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset, + kv_end, warp_idx, lane_idx); + cp_async::commit_group(); + page_produce_kv(smem_storage, &v_smem_offset_w, v, + kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset, + kv_end, warp_idx, lane_idx); + cp_async::commit_group(); + + // loop with mask + LOOP_SPLIT_MASK( + kv_tile_idx, kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0, + kv_tile_idx + 1 > NUM_STAGES, { + prefetch_offest(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV, + packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h, + k_stride_n, block_size, kv_indices, thr_local_kv_offset); + cp_async::wait_group<1>(); + __syncthreads(); + + compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + if constexpr (WITH_MASK) { + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx); + } + update_mdo_states(variant, s_frag, o_frag, m, d); + + __syncthreads(); + page_produce_kv(smem_storage, &k_smem_offset_w, k, + kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, + thr_local_kv_offset, kv_end, warp_idx, lane_idx); + cp_async::commit_group(); + cp_async::wait_group<1>(); + + __syncthreads(); + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + __syncthreads(); + + page_produce_kv(smem_storage, &v_smem_offset_w, v, + kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, + thr_local_kv_offset, kv_end, warp_idx, lane_idx); + cp_async::commit_group(); + }); + cp_async::wait_group<0>(); + __syncthreads(); + +#pragma unroll + for (; kv_tile_idx >= 0; --kv_tile_idx) { + compute_qk(&q_smem, &q_smem_offset_r, &k_smem, 
&k_smem_offset_r, s_frag); + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + kv_start + + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx); + update_mdo_states(variant, s_frag, o_frag, m, d); + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + } + + __syncthreads(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + // write back to global memory + // o_indptr: [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim] + DTypeO* o_ptr_base = + partial_o + ((o_indptr + kv_chunk_idx) * num_kv_heads + kv_head_idx) * HEAD_DIM_VO; + write_o_(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, packed_qo_start, + qo_upperbound, num_kv_chunks * num_kv_heads * HEAD_DIM_VO, gqa_group_size, + warp_idx, lane_idx, tid); + // write lse to partial lse + if constexpr (variant.use_softmax) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + const uint32_t packed_qo_idx = qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16; + gqa_group_size.divmod(packed_qo_idx, q, r); + if (q < qo_upperbound) { + partial_lse[(o_indptr + (packed_qo_idx - packed_qo_start) * num_kv_chunks + + kv_chunk_idx) * + num_kv_heads + + kv_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } + } + } + } + } + } + } +}; + +template +struct StateReductionKernelTraits { + using DTypeIn = DTypeIn_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + static constexpr uint32_t HEAD_DIM_VO = HEAD_DIM_VO_; + static constexpr uint32_t NUM_SMEM_STAGES = NUM_SMEM_STAGES_; + static constexpr uint32_t NUM_THREADS = NUM_THREADS_; + + static constexpr uint32_t vec_size = (16U / sizeof(DTypeIn)) > (HEAD_DIM_VO / 32U) + ? 
(16U / sizeof(DTypeIn)) + : (HEAD_DIM_VO / 32U); + static constexpr uint32_t bdx = HEAD_DIM_VO / vec_size; + + // gridDim is accessed by runtime variable and should be set by core attention + static_assert(NUM_THREADS % bdx == 0); + static constexpr uint32_t bdy = NUM_THREADS / bdx; + + // pipeline load & reduction + static constexpr size_t SMEM_SIZE = + NUM_SMEM_STAGES * bdy * HEAD_DIM_VO * sizeof(DTypeIn) + NUM_THREADS * sizeof(float); +}; + +template +struct BlockBatchReductionPersistent { + using KTraits = KTraits_; + + static __device__ __forceinline__ void Run( + typename KTraits::DTypeIn* __restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged, + float* __restrict__ S, float* __restrict__ s_merged, + const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size, + const uint32_t num_kv_heads, const typename KTraits::IdType* indptr, + const typename KTraits::IdType* o_indices, uint8_t* smem) { + using DTypeIn = typename KTraits::DTypeIn; + using DTypeO = typename KTraits::DTypeO; + using IdType = typename KTraits::IdType; + + [[maybe_unused]] constexpr uint32_t bdx = KTraits::bdx; + [[maybe_unused]] constexpr uint32_t bdy = KTraits::bdy; + [[maybe_unused]] constexpr uint32_t vec_size = KTraits::vec_size; + [[maybe_unused]] constexpr uint32_t head_dim = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t num_smem_stages = KTraits::NUM_SMEM_STAGES; + [[maybe_unused]] constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + + // control flow metadata + uint32_t tx = threadIdx.x % bdx, ty = threadIdx.x / bdx; + uint32_t cta_id = blockIdx.y; + uint32_t num_ctas = gridDim.x * gridDim.y * gridDim.z; + + DTypeIn* v_smem = (DTypeIn*)smem; + float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); + + // V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim] + // v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim] +#pragma unroll 1 + for (uint32_t i = cta_id; i < num_packed_qo_len * num_kv_heads; i += num_ctas) { + // remap workload + uint32_t packed_qo_idx = i / num_kv_heads; + uint32_t kv_head_idx = i % num_kv_heads; + uint32_t qo_head_idx = packed_qo_idx % gqa_group_size; + + // index calculation + auto partial_idx_to_offset = [&](uint32_t off) { + return (indptr[packed_qo_idx] + off) * num_kv_heads + kv_head_idx; + }; + auto merge_idx_to_offset = [&]() { + return (o_indices[packed_qo_idx] * num_kv_heads + kv_head_idx) * gqa_group_size + + qo_head_idx; + }; + + state_t st; + const uint32_t num_index_sets = indptr[packed_qo_idx + 1] - indptr[packed_qo_idx]; + + if (num_index_sets == 0) { + vec_t v; + v.fill(DTypeO(0.f)); + v.store(v_merged + merge_idx_to_offset() * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[merge_idx_to_offset()] = -math::inf; + } + continue; + } + + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + partial_idx_to_offset(0) * head_dim + tx * vec_size); + v.store(v_merged + merge_idx_to_offset() * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[merge_idx_to_offset()] = S[partial_idx_to_offset(0)]; + } + continue; + } + +#pragma unroll + for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { + cp_async::pred_load( + v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, + V + partial_idx_to_offset(iter * bdy + ty) * head_dim + tx * vec_size, + (iter * bdy + ty) < num_index_sets); + cp_async::commit_group(); + } +#pragma unroll 4 + for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { + if (iter % bdx == 0) { + s_smem[ty * 
bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets + ? S[partial_idx_to_offset(iter * bdy + ty * bdx + tx)] + : 0.f; + __syncthreads(); + } + cp_async::wait_group(); + __syncthreads(); + vec_t v; + v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); + if (iter * bdy + ty < num_index_sets) { + float s = s_smem[(iter % bdx) * bdy + ty]; + st.merge(v, s, 1); + } + __syncthreads(); + cp_async::pred_load( + v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, + V + partial_idx_to_offset((iter + num_smem_stages) * bdy + ty) * head_dim + + tx * vec_size, + (iter + num_smem_stages) * bdy + ty < num_index_sets); + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + __syncthreads(); + + st.normalize(); + threadblock_sync_state(st, v_smem, s_smem, tx, ty); + st.normalize(); + + st.o.cast_store(v_merged + merge_idx_to_offset() * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[merge_idx_to_offset()] = st.get_lse(); + } + } + } +}; + +template +cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2, + const uint32_t num_blks_x, const uint32_t num_blks_y, + const cudaStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + constexpr uint32_t NUM_WARPS_Q_1 = get_num_warps_q(CTA_TILE_Q_1); + constexpr uint32_t NUM_WARPS_KV_1 = get_num_warps_kv(CTA_TILE_Q_1); + constexpr uint32_t NUM_MMA_Q_1 = get_num_mma_q(CTA_TILE_Q_1); + constexpr uint32_t NUM_MMA_KV_1 = 4; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using KTraits1 = KernelTraits; + constexpr uint32_t NUM_WARPS_Q_2 = get_num_warps_q(CTA_TILE_Q_2); + constexpr uint32_t NUM_WARPS_KV_2 = get_num_warps_kv(CTA_TILE_Q_2); + constexpr uint32_t NUM_MMA_Q_2 = get_num_mma_q(CTA_TILE_Q_2); + constexpr uint32_t NUM_MMA_KV_2 = 2; + using KTraits2 = KernelTraits; + + // Attention state reduction kernel + constexpr uint32_t NUM_THREADS = + KTraits1::NUM_THREADS > KTraits2::NUM_THREADS ? KTraits1::NUM_THREADS : KTraits2::NUM_THREADS; + using ReductionKTraits = + StateReductionKernelTraits; + size_t smem_size = + max(sizeof(typename KTraits1::SharedStorage), sizeof(typename KTraits2::SharedStorage)); + smem_size = max(smem_size, ReductionKTraits::SMEM_SIZE); + + // Launch persistent kernel + auto kernel = PersistentKernelTemplate, + BlockBatchPagedAttentionPersistent, + BlockBatchReductionPersistent>; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + dim3 nblks(num_blks_x, num_blks_y); + dim3 nthrs(NUM_THREADS); + void* args[] = {(void*)¶ms_1, (void*)¶ms_2}; + FLASHINFER_CUDA_CALL( + cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + return cudaSuccess; +} + +}; // namespace flashinfer + +#endif // FLASHINFER_PERSISTENT_CUH_ diff --git a/include/flashinfer/attention/persistent_template.cuh b/include/flashinfer/attention/persistent_template.cuh new file mode 100644 index 0000000000..07ec51ab58 --- /dev/null +++ b/include/flashinfer/attention/persistent_template.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH +#define FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH + +#include +#include + +#include + +#include "../profiler.cuh" + +namespace flashinfer { +namespace cg = cooperative_groups; + +// Define profiler event types for persistent kernels +enum class PersistentProfileEventType { + kRunner1 = 0U, + kRunner2 = 1U, + kRunner3 = 2U, + kRunner4 = 3U, +}; + +// Helper metafunction to find maximum threads among multiple BlockPersistentRunners +template +struct max_threads; + +template +struct max_threads { + static constexpr size_t value = Runner::KTraits::NUM_THREADS; +}; + +template +struct max_threads { + static constexpr size_t value = Runner1::KTraits::NUM_THREADS > Runner2::KTraits::NUM_THREADS + ? max_threads::value + : max_threads::value; +}; + +// Two runners version +template +__global__ __launch_bounds__( + max_threads:: + value) void PersistentKernelTemplate(const __grid_constant__ + typename BlockPersistentRunner1::Params params_1, + const __grid_constant__ + typename BlockPersistentRunner2::Params params_2) { + extern __shared__ uint8_t smem[]; + + PROFILER_INIT(params_1, smem, profiler_closure, 0, 1, (threadIdx.x == 0)); + + auto& smem_storage_1 = + reinterpret_cast(smem); + + PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner1); + BlockPersistentRunner1::Run(params_1, &smem_storage_1); + PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1); + + __syncthreads(); + + auto& smem_storage_2 = + reinterpret_cast(smem); + + PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner2); + BlockPersistentRunner2::Run(params_2, &smem_storage_2); + PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner2); + + // NOTE(Yilong): optimize the barrier + auto grid = cg::this_grid(); + grid.sync(); + BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, + params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, + params_1.merge_o_indices, smem); +} +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 4f7deb2f3f..1cd2aafe92 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -98,6 +98,7 @@ template struct KernelTraits { + static constexpr uint32_t NUM_STAGES = 1; // used for BatchAttention Template static constexpr MaskMode MASK_MODE = MASK_MODE_; static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_; static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_; @@ -322,12 +323,15 @@ __device__ __forceinline__ void produce_kv(smem_t smem } template -__device__ __forceinline__ void page_produce_kv( - smem_t smem, uint32_t* smem_offset, - const paged_kv_t& paged_kv, - const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, - const dim3 tid = threadIdx) { +__device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* smem_storage, + 
uint32_t* smem_offset, + typename KTraits::DTypeKV* kv_ptr, + const uint32_t kv_idx_base, + const size_t* thr_local_kv_offset, + const uint32_t kv_len, const uint32_t warp_idx, + const uint32_t lane_idx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + smem_t smem(produce_v ? smem_storage->v_smem : smem_storage->k_smem); using DType = typename KTraits::DTypeKV; using IdType = typename KTraits::IdType; constexpr SharedMemFillMode fill_mode = @@ -338,15 +342,13 @@ __device__ __forceinline__ void page_produce_kv( constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { - DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] - : paged_kv.k_data + thr_local_kv_offset[i]; + DType* gptr = kv_ptr + thr_local_kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); @@ -365,8 +367,7 @@ __device__ __forceinline__ void page_produce_kv( static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] - : paged_kv.k_data + thr_local_kv_offset[i]; + DType* gptr = kv_ptr + thr_local_kv_offset[i]; smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); kv_idx += NUM_WARPS * 8; *smem_offset = @@ -2186,11 +2187,11 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_iter, kv_head_idx, entry_idx, (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); } - page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, thr_local_kv_offset, - chunk_size, tid); + page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, 0, + thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); - page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, thr_local_kv_offset, - chunk_size, tid); + page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, 0, + thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); uint32_t num_iterations_prefix; @@ -2327,8 +2328,9 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, - thr_local_kv_offset, chunk_size, tid); + page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, + (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, + warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); @@ -2337,8 +2339,9 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, - thr_local_kv_offset, chunk_size, tid); + page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, + (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, + warp_idx, 
lane_idx); cp_async::commit_group(); } cp_async::wait_group<0>(); diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 89eb6a0821..04a2873ffa 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -50,6 +50,13 @@ std::tuple LaunchSpecForDecodeKernelMlaCuteSM80( template __global__ void BatchDecodeWithPagedKVCacheKernelMlaCuteSM80(Params params); +template +inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset, + const std::vector& vec) { + DType* ptr = GetPtrFromBaseOffset(page_locked_int_buffer, offset); + std::copy(vec.begin(), vec.end(), ptr); +} + /*! * \brief Compute the maximum number of pages per batch and the new batch size * after we partition Paged KV-Cache into multiple chunks on KV sequence length @@ -991,6 +998,342 @@ inline int packed_causal_kv_end(int qo_len, int kv_len, int qo_tile_idx, int clu return kv_len_init + (qo_tile_idx + 1) * cluster_tile_q / group_size; } +template +struct HolisticPlanInfo { + int64_t num_blks_x; + int64_t num_blks_y; + struct { + int64_t q_indptr_offset; + int64_t kv_indptr_offset; + int64_t partial_indptr_offset; + int64_t q_len_offset; + int64_t kv_len_offset; + int64_t q_start_offset; + int64_t kv_start_offset; + int64_t kv_end_offset; + int64_t kv_head_idx_offset; + int64_t work_indptr_offset; + int64_t len_kv_chunk_offset; + } tasks[NUM_TASKS]; + + int64_t partial_o_offset; + int64_t partial_lse_offset; + int64_t merge_indptr_offset; + int64_t merge_o_indices_offset; + int64_t num_qo_len_offset; + + static constexpr uint32_t NUM_TASK_ARGS = 11; + static constexpr uint32_t NUM_SHARED_ARGS = 7; + + std::vector ToVector() const { + std::vector vec; + vec.push_back(num_blks_x); + vec.push_back(num_blks_y); + for (uint32_t i = 0; i < NUM_TASKS; ++i) { + vec.push_back(tasks[i].q_indptr_offset); + vec.push_back(tasks[i].kv_indptr_offset); + vec.push_back(tasks[i].partial_indptr_offset); + vec.push_back(tasks[i].q_len_offset); + vec.push_back(tasks[i].kv_len_offset); + vec.push_back(tasks[i].q_start_offset); + vec.push_back(tasks[i].kv_start_offset); + vec.push_back(tasks[i].kv_end_offset); + vec.push_back(tasks[i].kv_head_idx_offset); + vec.push_back(tasks[i].work_indptr_offset); + vec.push_back(tasks[i].len_kv_chunk_offset); + } + vec.push_back(partial_o_offset); + vec.push_back(partial_lse_offset); + vec.push_back(merge_indptr_offset); + vec.push_back(merge_o_indices_offset); + vec.push_back(num_qo_len_offset); + return vec; + } + + void FromVector(const std::vector& vec) { + if (vec.size() != NUM_SHARED_ARGS + NUM_TASKS * NUM_TASK_ARGS) { + std::ostringstream err_msg; + err_msg << "HolisticPlanInfo::FromVector: vec.size() should be " + << NUM_SHARED_ARGS + NUM_TASKS * NUM_TASK_ARGS << ", but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + num_blks_x = vec[0]; + num_blks_y = vec[1]; + for (uint32_t i = 0; i < NUM_TASKS; ++i) { + tasks[i].q_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 0]; + tasks[i].kv_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 1]; + tasks[i].partial_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 2]; + tasks[i].q_len_offset = vec[2 + i * NUM_TASK_ARGS + 3]; + tasks[i].kv_len_offset = vec[2 + i * NUM_TASK_ARGS + 4]; + tasks[i].q_start_offset = vec[2 + i * NUM_TASK_ARGS + 5]; + tasks[i].kv_start_offset = vec[2 + i * NUM_TASK_ARGS + 6]; + tasks[i].kv_end_offset = vec[2 + i * NUM_TASK_ARGS + 7]; + tasks[i].kv_head_idx_offset = vec[2 + i * NUM_TASK_ARGS + 8]; + tasks[i].work_indptr_offset = 
vec[2 + i * NUM_TASK_ARGS + 9]; + tasks[i].len_kv_chunk_offset = vec[2 + i * NUM_TASK_ARGS + 10]; + } + partial_o_offset = vec[2 + NUM_TASKS * NUM_TASK_ARGS]; + partial_lse_offset = vec[3 + NUM_TASKS * NUM_TASK_ARGS]; + merge_indptr_offset = vec[4 + NUM_TASKS * NUM_TASK_ARGS]; + merge_o_indices_offset = vec[5 + NUM_TASKS * NUM_TASK_ARGS]; + num_qo_len_offset = vec[6 + NUM_TASKS * NUM_TASK_ARGS]; + } +}; + +template +inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, + HolisticPlanInfo<2>& plan_info, IdType* qo_indptr_h, + IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, bool causal, + cudaStream_t stream) { + constexpr uint32_t NUM_TASKS = 2; + const uint32_t CTA_TILE_Q_SIZES[NUM_TASKS] = {128, 16}; + int num_sm = 0; + int dev_id = 0; + + uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + + if (head_dim >= 256) { + // NOTE (Yilong): optimize this code path + // constraint gridDim due to cooperative group + num_sm *= 1; + } else { + // NOTE(Zihao): two cta per sm + num_sm *= 2; + } + + // step 0. determine the number of blocks in x and y dimensions + std::vector> idx_qo_kv_len_vec[NUM_TASKS]; + for (uint32_t i = 0; i < batch_size; ++i) { + if (qo_indptr_h[i + 1] - qo_indptr_h[i] < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + + int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + int packed_qo_len = qo_len * gqa_group_size; + int kv_len = kv_len_arr_h[i]; + + // TODO(Zihao): add more stages + if (packed_qo_len > CTA_TILE_Q_SIZES[1]) { + idx_qo_kv_len_vec[0].push_back({i, qo_len, kv_len}); + } else { + idx_qo_kv_len_vec[1].push_back({i, qo_len, kv_len}); + } + } + + int cluster_size = 1; + int num_clusters = num_sm / cluster_size; + plan_info.num_blks_x = cluster_size; + plan_info.num_blks_y = num_clusters; + + auto f = [](int x) { + if (x <= 8) { + return 32; + } else if (x <= 16) { + return 64; + } else if (x <= 32) { + return 128; + } else if (x <= 64) { + return 192; + } + return ceil_div(x, 256) * 256; + }; + + MinHeap cluster_cost_heap(num_clusters); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + + // NOTE(Zihao): adjust it later + const int max_total_num_works = 16384; + const int max_packed_qo_lens = + 4 * num_clusters * cluster_size * (CTA_TILE_Q_SIZES[0] + CTA_TILE_Q_SIZES[1]); + ; // max_partial_num_rows + + // used for remapping the output offsets + // layout [packed_qo_len x num_kv_tiels, num_kv_heads, head_dim] + int partial_o_nnz = 0; + std::vector merge_indptr, merge_o_indices, num_expand_qo_len_vec; + merge_indptr.push_back(partial_o_nnz); + for (uint32_t task = 0; task < NUM_TASKS; ++task) { + int64_t total_kv_lens = 0; + int cluster_tile_q = CTA_TILE_Q_SIZES[task] * cluster_size; + for (auto& [_, qo_len, kv_len] : idx_qo_kv_len_vec[task]) { + int packed_qo_len = qo_len * gqa_group_size; + int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + int effective_kv_len = + causal ? 
packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q, num_qo_tiles, + gqa_group_size) + : kv_len; + total_kv_lens += effective_kv_len; + } + } + int kv_len_limit = f(std::max(ceil_div(total_kv_lens, num_clusters), 1L)); + + std::vector> cluster_q_indptr(num_clusters, std::vector()), + cluster_kv_indptr(num_clusters, std::vector()), + cluster_q_len(num_clusters, std::vector()), + cluster_kv_len(num_clusters, std::vector()), + cluster_q_start(num_clusters, std::vector()), + cluster_kv_start(num_clusters, std::vector()), + cluster_kv_end(num_clusters, std::vector()), + cluster_kv_head_idx(num_clusters, std::vector()), + cluster_partial_indptr(num_clusters, std::vector()), + cluster_len_kv_chunk(num_clusters, std::vector()); + + for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec[task]) { + int packed_qo_len = qo_len * gqa_group_size; + int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + int remaining_len = causal + ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q, + num_qo_tiles, gqa_group_size) + : kv_len; + int kv_start = 0; + bool split_kv = remaining_len > kv_len_limit; + int num_kv_tiles = split_kv ? ceil_div(remaining_len, kv_len_limit) : 1; + int row_tile_size = std::min(cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q); + bool zero_kv_len = (remaining_len == 0); + while (remaining_len > 0 || zero_kv_len) { + int actual_len = std::min(remaining_len, kv_len_limit); + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + auto [cluster_idx, accum_cost] = cluster_cost_heap.pop(); + cluster_cost_heap.insert( + {cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)}); + cluster_q_len[cluster_idx].push_back(qo_len); + cluster_kv_len[cluster_idx].push_back(kv_len); + cluster_q_indptr[cluster_idx].push_back(qo_indptr_h[i]); + cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]); + + // use kv_chunk to rematerize num_kv_tiles and kv_tile_idx + cluster_len_kv_chunk[cluster_idx].push_back(kv_len_limit); + cluster_partial_indptr[cluster_idx].push_back(partial_o_nnz); + + cluster_q_start[cluster_idx].push_back(qo_tile_idx * cluster_tile_q); + cluster_kv_start[cluster_idx].push_back(kv_start); + cluster_kv_end[cluster_idx].push_back(kv_start + actual_len); + cluster_kv_head_idx[cluster_idx].push_back(kv_head_idx); + } + remaining_len -= actual_len; + zero_kv_len = (remaining_len == 0); + kv_start += actual_len; + if (zero_kv_len) { + break; + } + } + for (int row = 0; row < row_tile_size; ++row) { + merge_indptr.push_back(merge_indptr.back() + num_kv_tiles); + merge_o_indices.push_back(qo_indptr_h[i] + + (qo_tile_idx * cluster_tile_q + row) / gqa_group_size); + } + partial_o_nnz += row_tile_size * num_kv_tiles; + } + } + + std::vector work_indptr_vec(num_clusters + 1, 0); + for (uint32_t i = 0; i < num_clusters; ++i) { + work_indptr_vec[i + 1] = work_indptr_vec[i] + cluster_q_indptr[i].size(); + } + int total_num_works = work_indptr_vec.back(); + auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works); + auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works); + auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works); + auto q_len_vec = flatten(cluster_q_len, total_num_works); + auto kv_len_vec = flatten(cluster_kv_len, total_num_works); + auto q_start_vec = flatten(cluster_q_start, total_num_works); + auto kv_start_vec = flatten(cluster_kv_start, total_num_works); + auto kv_end_vec = 
flatten(cluster_kv_end, total_num_works); + auto kv_head_idx_vec = flatten(cluster_kv_head_idx, total_num_works); + auto len_kv_chunk_vec = flatten(cluster_len_kv_chunk, total_num_works); + + plan_info.tasks[task].q_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_indptr"); + plan_info.tasks[task].kv_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_indptr"); + plan_info.tasks[task].partial_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "partial_indptr"); + plan_info.tasks[task].q_len_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_len"); + plan_info.tasks[task].kv_len_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_len"); + plan_info.tasks[task].q_start_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_start"); + plan_info.tasks[task].kv_start_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_start"); + plan_info.tasks[task].kv_end_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_end"); + plan_info.tasks[task].kv_head_idx_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_head_idx"); + plan_info.tasks[task].work_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "work_indptr"); + plan_info.tasks[task].len_kv_chunk_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "len_kv_chunk"); + + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_indptr_offset, + q_indptr_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_indptr_offset, + kv_indptr_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].partial_indptr_offset, + partial_indptr_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_len_offset, q_len_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_len_offset, kv_len_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_start_offset, + q_start_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_start_offset, + kv_start_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_end_offset, kv_end_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_head_idx_offset, + kv_head_idx_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].work_indptr_offset, + work_indptr_vec); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].len_kv_chunk_offset, + len_kv_chunk_vec); + } + + if (partial_o_nnz > max_packed_qo_lens) { + std::ostringstream err_msg; + err_msg << "partial_o_nnz " << partial_o_nnz << " exceeds max_packed_qo_lens " + << max_packed_qo_lens; + FLASHINFER_ERROR(err_msg.str()); + } + + // update num_qo_len_vec + num_expand_qo_len_vec.push_back(merge_indptr.size() - 1); + // allocate buffer for state merge function + plan_info.merge_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_packed_qo_lens, 16, "merge_indptr"); + plan_info.merge_o_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_packed_qo_lens, 16, "merge_o_indices"); + plan_info.num_qo_len_offset = + 
int_allocator.aligned_alloc_offset(sizeof(IdType), 16, "num_qo_len_offset"); + // copy data to paged cpu buffer + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_indptr_offset, merge_indptr); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_o_indices_offset, merge_o_indices); + CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.num_qo_len_offset, + num_expand_qo_len_vec); + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + constexpr size_t sizeof_dtype_o = 2; // NOTE (Yilong): assume fp16 + + // Note(Yilong): adjust it later + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.partial_o_offset = float_allocator.aligned_alloc_offset( + 2 * max_packed_qo_lens * sizeof_dtype_o * head_dim, 16, "holistic_partial_o"); + plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset( + 2 * max_packed_qo_lens * sizeof(float), 16, "holistic_partial_lse"); + + return cudaSuccess; +} + struct MLAPlanInfo { int64_t num_blks_x; int64_t num_blks_y; diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index 8c76f1ef05..a63283ebdc 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -44,6 +44,16 @@ constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { return sizeof(b128_t) / sizeof(T); } +template +__device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + return i * stride + (j ^ (i % 8)); + } else { + // swizzle_mode == SwizzleMode::k64B + return i * stride + (j ^ ((i / 2) % 4)); + } +} + /*! * \brief The shared memory wrapper. */ diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 5c349cc915..9f21f5b2b5 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -321,6 +321,18 @@ inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_di } } +#define LOOP_SPLIT_MASK(iter, COND1, COND2, ...) \ + { \ + _Pragma("unroll 1") for (; (COND1); (iter) -= 1) { \ + constexpr bool WITH_MASK = true; \ + __VA_ARGS__ \ + } \ + _Pragma("unroll 1") for (; (COND2); (iter) -= 1) { \ + constexpr bool WITH_MASK = false; \ + __VA_ARGS__ \ + } \ + } + /*! * \brief Return x - y if x > y, otherwise return 0. */ diff --git a/tests/test_batch_attention.py b/tests/test_batch_attention.py new file mode 100644 index 0000000000..b413d17d83 --- /dev/null +++ b/tests/test_batch_attention.py @@ -0,0 +1,171 @@ +# test_flashinfer_attention.py +import numpy as np +import pytest +import torch + +import flashinfer + + +# ------------------------- Configuration generation function ----------------------------- # +def _build_seq_len_configs(): + """ + Reproduce the sequence length configurations from the original benchmark (including random cases). + Returns: List[List[Tuple[int,int]]] -> Each element is a list of (kv_len, qo_len) pairs. 
+ """ + np.random.seed(42) + torch.manual_seed(42) + + seq_len_configs = [ + [(2048, 1)] * 77, # decode-only + [(4099, 129)] * 2, # prefill-only + [(600, 1)] * 132 * 2 + [(5000, 3)] * 128, + [(1024, 1)] * 100 + [(8192, 17)] * 8, # speculative decode + [(766, 2)] * 99 + [(1024, 512)] * 1, # chunked prefill + ] + + # Construct random seqlen tests + bsz, stride, sparsity = 256, 16, 0.05 + full_kv_len = np.random.randint(1000, 11000, size=bsz) + seq_len = [] + for i in range(bsz): + if i % stride == 0: + kv_len, qo_len = full_kv_len[i], stride + 1 + else: + kv_len, qo_len = int(full_kv_len[i] * sparsity), 1 + seq_len.append((kv_len, qo_len)) + seq_len_configs.append(seq_len) + + return seq_len_configs + + +def _run_attention( + kv_lens, + qo_lens, + page_block_size=1, + num_kv_heads=1, + num_qo_heads=1, + head_dim=128, + layout="NHD", + test_dtype=torch.bfloat16, + device="cuda", + causal=True, +): + """ + Run both implementations and return (output_old, lse_old, output_new, lse_new) + """ + dev = torch.device(device) + seq_lens = torch.tensor(kv_lens, dtype=torch.int32) + q_lens = torch.tensor(qo_lens, dtype=torch.int32) + + seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() + + q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() + kv_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 + ).int() + + num_blocks = kv_indptr[-1].item() + + q = torch.rand( + q_indptr[-1].item(), num_qo_heads, head_dim, dtype=test_dtype, device=dev + ) + if layout == "NHD": + kv_data = torch.randn( + num_blocks, + 2, + page_block_size, + num_kv_heads, + head_dim, + dtype=test_dtype, + device=dev, + ) + elif layout == "HND": + kv_data = torch.randn( + num_blocks, + 2, + num_kv_heads, + page_block_size, + head_dim, + dtype=test_dtype, + device=dev, + ) + + # --------- old scheduler --------- # + wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev), + kv_layout=layout, + backend="fa2", + ) + last_page_len = (seq_lens - 1) % page_block_size + 1 + wrapper_old.plan( + q_indptr.to(dev), + kv_indptr.to(dev), + torch.arange(num_blocks, device=dev).int(), + last_page_len.to(dev), + num_qo_heads, + num_kv_heads, + head_dim, + page_block_size, + causal=causal, + q_data_type=test_dtype, + kv_data_type=test_dtype, + ) + out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True) + + # --------- new / mixed scheduler --------- # + wrapper = flashinfer.BatchAttention(kv_layout=layout) + wrapper.plan( + q_indptr.to(dev), + kv_indptr.to(dev), + torch.arange(num_blocks, device=dev).int(), + seq_lens.to(dev), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=test_dtype, + kv_data_type=test_dtype, + ) + out_new, lse_new = wrapper.run(q, kv_data) + + torch.cuda.synchronize() + torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2) + + +# ------------------------- PyTest test case ----------------------------- # +@pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs()) +@pytest.mark.parametrize("page_block_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("gqa_group_size", [1, 4]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("layout", ["HND", "NHD"]) +@pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16]) +def test_batch_attention_correctness( + seq_len_pairs, + 
page_block_size, + num_kv_heads, + gqa_group_size, + head_dim, + causal, + layout, + test_dtype, +): + num_qo_heads = num_kv_heads * gqa_group_size + kv_lens = [p[0] for p in seq_len_pairs] + qo_lens = [p[1] for p in seq_len_pairs] + + _run_attention( + kv_lens=kv_lens, + qo_lens=qo_lens, + page_block_size=page_block_size, + num_kv_heads=num_kv_heads, + num_qo_heads=num_qo_heads, + head_dim=head_dim, + causal=causal, + layout=layout, + test_dtype=test_dtype, + device="cuda", + )
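
Usage note: the correctness test above drives the new unified wrapper the same way in every case: pack per-request query lengths and page counts into `q_indptr`/`kv_indptr`, call `plan()` once per batch composition, then `run()` on the packed queries and the paged KV-cache. Below is a minimal standalone sketch of that flow; the two-request batch, head counts, page size, and bfloat16 dtypes are illustrative choices, not API requirements.

import torch
import flashinfer

device = torch.device("cuda")
num_qo_heads, num_kv_heads, head_dim, page_size = 28, 4, 128, 16

# Two requests: one decode-like (kv_len=4096, qo_len=1), one prefill-like (kv_len=512, qo_len=512).
kv_lens = torch.tensor([4096, 512], dtype=torch.int32)
qo_lens = torch.tensor([1, 512], dtype=torch.int32)

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(qo_lens, 0)]).int()
pages_per_seq = torch.ceil(kv_lens / page_size).int()
kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(pages_per_seq, 0)]).int()
num_pages = int(kv_indptr[-1])

q = torch.rand(int(q_indptr[-1]), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
# Paged KV-cache in NHD layout: [num_pages, 2 (K/V), page_size, num_kv_heads, head_dim].
kv_cache = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device=device)

wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
    q_indptr.to(device),
    kv_indptr.to(device),
    torch.arange(num_pages, dtype=torch.int32, device=device),  # page indices
    kv_lens.to(device),
    num_qo_heads,
    num_kv_heads,
    head_dim,  # head_dim_qk
    head_dim,  # head_dim_vo
    page_size,
    causal=True,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)
out, lse = wrapper.run(q, kv_cache)

`run()` returns the attention output and the per-row log-sum-exp, matching `out_new, lse_new = wrapper.run(q, kv_data)` in the test.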
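
Scheduling note: `TwoStageHolisticPlan` in scheduler.cuh buckets requests into a large-tile task (CTA_TILE_Q = 128) and a small-tile task (CTA_TILE_Q = 16), splits each query tile's KV range into chunks of at most `kv_len_limit`, and assigns every (query tile, KV chunk, KV head) work item to the currently least-loaded cluster through a min-heap keyed on accumulated cost. The sketch below models only that balancing step under simplifying assumptions: the causal truncation performed by `packed_causal_kv_end` is omitted, the cost model is a stand-in for the patch's `cost_function`, and `plan_work` is a hypothetical name, not part of FlashInfer.

import heapq
from typing import List, Tuple


def plan_work(
    packed_qo_kv: List[Tuple[int, int]],  # per-request (packed_qo_len, kv_len)
    num_clusters: int,
    cluster_tile_q: int,
    kv_len_limit: int,
    num_kv_heads: int,
):
    # (accumulated cost, cluster id); heapq always pops the least-loaded cluster.
    heap = [(0, c) for c in range(num_clusters)]
    heapq.heapify(heap)
    work = [[] for _ in range(num_clusters)]

    for req_id, (packed_qo_len, kv_len) in enumerate(packed_qo_kv):
        num_qo_tiles = -(-packed_qo_len // cluster_tile_q)  # ceil_div
        for qo_tile in range(num_qo_tiles):
            # At least one (possibly empty) KV chunk per query tile.
            num_chunks = max(1, -(-kv_len // kv_len_limit))
            for c in range(num_chunks):
                kv_start = c * kv_len_limit
                kv_end = min(kv_len, kv_start + kv_len_limit)
                for kv_head in range(num_kv_heads):
                    cost, cluster = heapq.heappop(heap)
                    work[cluster].append((req_id, qo_tile, kv_start, kv_end, kv_head))
                    # Stand-in cost model: proportional to tile_q x chunk length.
                    heapq.heappush(
                        heap, (cost + cluster_tile_q * max(kv_end - kv_start, 1), cluster)
                    )
    return work


# Example: a decode-heavy request and a short prefill, balanced across 4 clusters.
per_cluster = plan_work([(28, 8192), (512, 600)], num_clusters=4,
                        cluster_tile_q=128, kv_len_limit=2048, num_kv_heads=4)

Keeping the per-cluster accumulated cost roughly equal is what lets the single persistent grid serve mixed prefill and decode work; the resulting per-cluster lists correspond to the flattened `work_indptr`/`q_start`/`kv_start`/`kv_end` arrays that the plan copies into the page-locked buffer.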
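
Shared-memory note: the `get_permuted_offset` helper added to permuted_smem.cuh exposes the XOR swizzle as a free function, so offsets can be computed without going through an `smem_t` instance. The snippet below simply evaluates the `SwizzleMode::k128B` formula `i * stride + (j ^ (i % 8))` (in 16-byte units) to show the property it relies on: across any 8 consecutive rows, a fixed logical column maps to 8 distinct physical slots, which is what avoids shared-memory bank conflicts.

# Evaluate the k128B swizzle in 16-byte (b128_t) units.
STRIDE = 8  # 8 x 16 B = 128 B per row segment


def permuted_offset(i: int, j: int) -> int:
    return i * STRIDE + (j ^ (i % 8))


j = 3  # a fixed logical column
slots = [permuted_offset(i, j) % 8 for i in range(8)]
print(slots)  # [3, 2, 1, 0, 7, 6, 5, 4]: all eight physical slots are distinct
assert len(set(slots)) == 8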