From d4e8d79b340589633943bebd827da17b3f4c29ad Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 16 Dec 2024 23:27:39 -0800 Subject: [PATCH] feat: add JIT compilation support for FA3 templates (#672) Follow up work of #667 --- flashinfer/jit/attention.py | 6 +- flashinfer/jit/batch_prefill_sm90_templ.py | 386 +++++++++++++++++++- flashinfer/jit/core.py | 2 + flashinfer/jit/single_prefill_sm90_templ.py | 148 +++++++- flashinfer/prefill.py | 34 +- tests/test_jit_warmup.py | 57 ++- 6 files changed, 610 insertions(+), 23 deletions(-) diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index ba24f7c1..1ead1215 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -28,7 +28,7 @@ batch_prefill_sm90_templ, ) from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ -from .core import load_cuda_ops +from .core import load_cuda_ops, sm90a_nvcc_flags from .env import FLASHINFER_GEN_SRC_DIR from .single_decode_templ import ( customizable_single_decode_templ, @@ -333,7 +333,7 @@ def gen_single_prefill_sm90_module(*args): source_paths.append(path) write_if_different(path, source) - return load_cuda_ops(uri, source_paths) + return load_cuda_ops(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags) def get_batch_prefill_sources( @@ -445,7 +445,7 @@ def gen_batch_prefill_sm90_module(*args): source_paths.append(path) write_if_different(path, source) - return load_cuda_ops(uri, source_paths) + return load_cuda_ops(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags) def get_customize_single_decode_sources( diff --git a/flashinfer/jit/batch_prefill_sm90_templ.py b/flashinfer/jit/batch_prefill_sm90_templ.py index c06c4aac..88bd2407 100644 --- a/flashinfer/jit/batch_prefill_sm90_templ.py +++ b/flashinfer/jit/batch_prefill_sm90_templ.py @@ -14,6 +14,388 @@ limitations under the License. 
""" -batch_prefill_sm90_suffix = [".cu", "_pybind.cc"] +batch_prefill_sm90_suffix = [ + "_plan.cu", + *[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + "_ragged_run.cu", + *[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + "_paged_run.cu", + "_pybind.cc", +] -batch_prefill_sm90_templ = [r"""""", r""""""] + +def ragged_prefill_sm90_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include +#include + +namespace flashinfer { + +using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; +using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; +using DTypeO = cutlass_dtype_t<{{dtype_o}}>; +using IdType = cutlass_dtype_t<{{dtype_idx}}>; + +using RaggedParams = + BatchPrefillRaggedParams; +using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ + + mask_mode + + r""", /*USE_SWA=*/true, AttentionVariant>( + RaggedParams& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ + + mask_mode + + r""", /*USE_SWA=*/false, AttentionVariant>( + RaggedParams& params, + cudaStream_t stream); + +}""" + ) + + +def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include +#include + +namespace flashinfer { + +using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; +using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; +using DTypeO = cutlass_dtype_t<{{dtype_o}}>; +using IdType = cutlass_dtype_t<{{dtype_idx}}>; + +using PagedParams = BatchPrefillPagedParams; +using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ + + mask_mode + + r""", /*USE_SWA=*/true, AttentionVariant>( + PagedParams& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ + + mask_mode + + r""", /*USE_SWA=*/false, AttentionVariant>( + PagedParams& params, + cudaStream_t stream); + +}""" + ) + + +batch_prefill_sm90_templ = [ + r"""#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +std::vector BatchPrefillWithKVCacheSM90Plan( + bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PrefillPlanSM90Info plan_info; + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = PrefillSM90Plan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{ dtype_idx }}>(), + kv_indptr.data_ptr<{{ dtype_idx }}>(), kv_len_arr.data_ptr<{{ dtype_idx }}>(), batch_size, num_qo_heads, + num_kv_heads, {{ head_dim }}, page_size, causal, enable_cuda_graph, sizeof({{dtype_o}}), stream); + + TORCH_CHECK(status == cudaSuccess, + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + + return 
plan_info.ToVector(); +} +""", + *[ + ragged_prefill_sm90_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], + r""" +#include +#include +#include +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillRaggedParams& params, cudaStream_t stream); + +}; // namespace flashinfer + +using namespace flashinfer; + +using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; +using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; +using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; +using IdType = cutlass_dtype_t<{{ dtype_idx }}>; + +using RaggedParams = BatchPrefillRaggedParams; + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + + RaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.head_dim = {{ head_dim }}; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { + using AttentionVariant = + std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; + cudaError_t status = + BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, + AttentionVariant>(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + }); +} +""", + *[ + paged_prefill_sm90_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], + r"""#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + BatchPrefillPagedParams& params, cudaStream_t stream); + +}; // namespace flashinfer + +using namespace flashinfer; + +using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; +using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; +using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; +using IdType = cutlass_dtype_t<{{ dtype_idx }}>; + +using PagedParams = BatchPrefillPagedParams; + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); 
+ TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + QKVLayout kv_layout = static_cast(layout); + unsigned int num_kv_heads, page_size; + unsigned int head_dim = q.size(2); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + } else { + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + + PagedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + // (num_pages, page_size, num_heads, head_dim) + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); + } else { + // (num_pages, num_heads, page_size, head_dim) + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); + } + params.nnz_qo = q.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = num_kv_heads; + params.group_size = params.num_qo_heads / num_kv_heads; + params.page_size = page_size; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); + + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { + using AttentionVariant = + std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; + cudaError_t status = + BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, + AttentionVariant>(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + }); +} +""", + r"""#include "pytorch_extension_utils.h" + +std::vector BatchPrefillWithKVCacheSM90Plan( + bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, 
unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("plan", &BatchPrefillWithKVCacheSM90Plan); + m.def("ragged_run", &BatchPrefillWithRaggedKVCacheSM90Run); + m.def("paged_run", &BatchPrefillWithPagedKVCacheSM90Run); +} +""", +] diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 32ab7ca6..62013a0e 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -74,6 +74,8 @@ def remove_unwanted_pytorch_nvcc_flags(): remove_unwanted_pytorch_nvcc_flags() +sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"] + def load_cuda_ops( name: str, diff --git a/flashinfer/jit/single_prefill_sm90_templ.py b/flashinfer/jit/single_prefill_sm90_templ.py index 917cf368..ac9f7112 100644 --- a/flashinfer/jit/single_prefill_sm90_templ.py +++ b/flashinfer/jit/single_prefill_sm90_templ.py @@ -14,6 +14,150 @@ limitations under the License. 
""" -single_prefill_sm90_suffix = [".cu", "_pybind.cc"] +single_prefill_sm90_suffix = [ + *[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + ".cu", + "_pybind.cc", +] -single_prefill_sm90_templ = [r"""""", r""""""] + +def single_prefill_sm90_inst_templ(mask_mode: str) -> str: + return ( + r"""#include +#include +#include + +namespace flashinfer { + +using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; +using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; +using DTypeO = cutlass_dtype_t<{{dtype_o}}>; + +using Params = SinglePrefillParams; +using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; + +template cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }},""" + f"{mask_mode}" + r""", /*USE_SWA=*/false, AttentionVariant>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }},""" + f"{mask_mode}" + r""", /*USE_SWA=*/true, AttentionVariant>( + Params& params, + cudaStream_t stream); + +} // namespace flashinfer +""" + ) + + +single_prefill_sm90_templ = [ + *[ + single_prefill_sm90_inst_templ(mask_mode) + for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] + ], + r"""#include +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { + unsigned int head_dim = q.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + + using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; + using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; + using DTypeO = cutlass_dtype_t<{{dtype_o}}>; + + SinglePrefillParams params; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
(static_cast(maybe_lse->data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.causal = mask_mode == MaskMode::kCausal; + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { + using AttentionVariant = + std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; + cudaError_t status = + SinglePrefillWithKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, AttentionVariant>( + params, stream); + TORCH_CHECK(status == cudaSuccess, + "single_prefill_with_kv_cache_sm90 failed with error: " + + std::string(cudaGetErrorString(status))); + }); +} +""", + r"""#include "pytorch_extension_utils.h" + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", &single_prefill_with_kv_cache_sm90, + "Single-request prefill attention with KV-Cache operator"); +} +""", +] diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index d5e8f5a3..fd8252e3 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -66,12 +66,12 @@ def get_single_prefill_sm90_module(*args): global _single_prefill_sm90_modules if args not in _single_prefill_sm90_modules: uri = get_single_prefill_sm90_uri(*args) - # if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _kernels_sm90 + if has_prebuilt_ops and uri in prebuilt_ops_uri: + from . import _kernels_sm90 - run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90 - # else: - # run_func = gen_single_prefill_sm90_module(*args).run + run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90 + else: + run_func = gen_single_prefill_sm90_module(*args).run @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "maybe_lse")) def run_single_prefill_sm90( @@ -221,17 +221,23 @@ def get_batch_prefill_sm90_module(*args): if args not in _batch_prefill_sm90_modules: uri = get_batch_prefill_sm90_uri(*args) - from . import _kernels_sm90 + if has_prebuilt_ops and uri in prebuilt_ops_uri: + from . 
import _kernels_sm90 - head_dim = args[4] - plan_func = ( - lambda *plan_args: _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan( - head_dim, - *plan_args, + head_dim = args[4] + plan_func = ( + lambda *plan_args: _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan( + head_dim, + *plan_args, + ) ) - ) - ragged_run_func = _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run - paged_run_func = _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run + ragged_run_func = _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run + paged_run_func = _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run + else: + module = gen_batch_prefill_sm90_module(*args) + plan_func = module.plan + ragged_run_func = module.ragged_run + paged_run_func = module.paged_run # torch library for ragged_run diff --git a/tests/test_jit_warmup.py b/tests/test_jit_warmup.py index dc0d5b89..77370e43 100644 --- a/tests/test_jit_warmup.py +++ b/tests/test_jit_warmup.py @@ -15,10 +15,10 @@ """ import torch -from flashinfer.jit import parallel_load_modules -from flashinfer.utils import PosEncodingMode import flashinfer +from flashinfer.jit import parallel_load_modules +from flashinfer.utils import PosEncodingMode def test_warmpup_llama(): @@ -58,3 +58,56 @@ def test_warmpup_llama(): ), ] ) + + +def test_warmpup_llama_sm90(): + parallel_load_modules( + [ + (flashinfer.activation.get_act_and_mul_module, ["silu"]), + (flashinfer.norm.get_norm_module, []), + (flashinfer.sampling.get_sampling_module, []), + (flashinfer.quantization.get_quantization_module, []), + (flashinfer.page.get_page_module, []), + ( + flashinfer.decode.get_batch_decode_module, + [ + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 128, + PosEncodingMode.NONE.value, + False, # use_sliding_window + False, # use_logits_soft_cap + ], + ), + ( + flashinfer.prefill.gen_batch_prefill_module, + [ + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 128, + PosEncodingMode.NONE.value, + False, # use_sliding_window + False, # use_logits_soft_cap + False, # allow_fp16_qk_reduction + ], + ), + ( + flashinfer.prefill.gen_batch_prefill_sm90_module, + [ + torch.float16, + torch.float16, + torch.float16, + torch.int32, + 128, + PosEncodingMode.NONE.value, + False, # use_sliding_window + False, # use_logits_soft_cap + False, # allow_fp16_qk_reduction + ], + ), + ] + )
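
Usage note (not part of the patch): a minimal sketch of the JIT path this change enables, modeled on the new test_warmpup_llama_sm90 test above and the fallback branch added to get_batch_prefill_sm90_module in flashinfer/prefill.py. It assumes an SM90 (Hopper) GPU and an environment where the prebuilt _kernels_sm90 extension is unavailable, so gen_batch_prefill_sm90_module compiles the generated sources through load_cuda_ops with sm90a_nvcc_flags; the argument order mirrors the warmup test.

# Sketch only: JIT-compile the SM90 batch-prefill module and grab its entry points.
import torch
import flashinfer
from flashinfer.utils import PosEncodingMode

module = flashinfer.prefill.gen_batch_prefill_sm90_module(
    torch.float16,               # dtype_q
    torch.float16,               # dtype_kv
    torch.float16,               # dtype_o
    torch.int32,                 # dtype_idx
    128,                         # head_dim (args[4] in get_batch_prefill_sm90_module)
    PosEncodingMode.NONE.value,  # pos_encoding_mode
    False,                       # use_sliding_window
    False,                       # use_logits_soft_cap
    False,                       # allow_fp16_qk_reduction
)

# The compiled extension exposes the three functions registered in the new
# batch-prefill _pybind.cc template: plan, ragged_run, and paged_run.
plan_func = module.plan
ragged_run_func = module.ragged_run
paged_run_func = module.paged_run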