Customizable SM90 prefill kernels. #704

Merged · 6 commits · Dec 29, 2024

84 changes: 47 additions & 37 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -18,12 +18,7 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from .literal_map import dtype_literal, idtype_literal, mask_mode_literal


def get_cu_file_str(
@@ -36,40 +31,56 @@ def get_cu_file_str(
dtype_out,
idtype,
):
pos_encoding_mode = None
allow_fp16_qk_reduction = None

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
return """
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>
(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>
(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
head_dim=head_dim,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = f""" // batch_paged_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

@@ -82,9 +93,9 @@ def get_insts(attention_variant):

using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}
{get_insts("LogitsSoftCap<Params>")}

{get_insts("StandardAttention")}
{get_insts("StandardAttention<Params>")}

}}"""
return content
@@ -93,12 +104,11 @@ def get_insts(attention_variant):
if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
path.write_text(get_cu_file_str(*match.groups()))
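
For orientation, here is a sketch of the kind of file this generator now emits, assuming head_dim=128, mask_mode=1 (causal), f16 dtypes, and int32 indices, and assuming the literal maps resolve to the names shown; the verbatim output is whatever `get_cu_file_str` produces, and only one of the eight emitted instantiations (2 SWA × 2 scheduling × 2 variants) is reproduced here.

```cuda
// Sketch of a generated batch_paged_prefill_*_sm90.cu; illustrative only.
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

namespace flashinfer {

// Assumes dtype_literal["f16"] resolves to "half".
using DTypeQ = cutlass_dtype_t<half>;
using DTypeKV = cutlass_dtype_t<half>;
using DTypeO = cutlass_dtype_t<half>;

using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, int32_t>;

// One of the eight explicit instantiations; note the variant is now
// parameterized on Params.
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
    <128,
     MaskMode::kCausal,
     /*USE_SWA=*/true,
     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
     LogitsSoftCap<Params>>
    (Params& params, cudaStream_t stream);

}  // namespace flashinfer
```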
57 changes: 33 additions & 24 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
@@ -38,39 +38,48 @@ def get_cu_file_str(
):

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
return """
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>(Params& params, cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = f""" // batch_ragged_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

@@ -83,9 +92,9 @@ def get_insts(attention_variant):

using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}
{get_insts("LogitsSoftCap<Params>")}

{get_insts("StandardAttention")}
{get_insts("StandardAttention<Params>")}

}}
"""
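The switch from `LogitsSoftCap` to `LogitsSoftCap<Params>` in both generators is the core of the customizability change: attention variants are now class templates over their parameter struct and expose a `ParamsT` typedef that the dispatchers consume. The real definitions live in flashinfer/attention/hopper/variants.cuh; below is only a minimal sketch of the shape this implies, with the constructor, members, and hook names invented for illustration.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Minimal sketch of a Params-templated variant. Everything except the
// ParamsT typedef is a guessed interface, not flashinfer's actual one.
template <typename Params>
struct LogitsSoftCapSketch {
  using ParamsT = Params;  // dispatchers recover the params type from here
  float pre_tanh_scale;    // scale folded in before the tanh

  __device__ explicit LogitsSoftCapSketch(const ParamsT& params)
      : pre_tanh_scale(params.sm_scale / params.logits_soft_cap) {}

  // Bound each attention logit to (-logits_soft_cap, +logits_soft_cap).
  __device__ float LogitsTransform(const ParamsT& params, float logits) const {
    return params.logits_soft_cap * tanhf(logits * pre_tanh_scale);
  }
};
```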
38 changes: 20 additions & 18 deletions aot_build_utils/generate_single_prefill_sm90_inst.py
@@ -30,7 +30,9 @@ def get_cu_file_str(
dtype_kv,
dtype_out,
):
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = """ // single_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

@@ -42,31 +44,32 @@

using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap<Params>>
(Params& params, cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap<Params>>
(Params& params, cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention<Params>>
(Params& params, cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention<Params>>
(Params& params, cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
Params& params,
cudaStream_t stream);
}}
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
# pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
# allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
use_custom_mask="true" if int(mask_mode) == 2 else "false",
# use_custom_mask="true" if int(mask_mode) == 2 else "false",
)
return content

@@ -81,5 +84,4 @@ def get_cu_file_str(
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
path.write_text(get_cu_file_str(*match.groups()))
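
The integer mask codes parsed out of the filename map onto flashinfer's MaskMode enum; the commented-out `use_custom_mask` line above relied on code 2 meaning a custom mask. For orientation, the assumed correspondence (a sketch, not the verbatim flashinfer header):

```cuda
// Assumed mapping between the generators' integer mask codes and
// flashinfer's MaskMode enum (cf. include/flashinfer/attention/mask.cuh).
enum class MaskMode {
  kNone = 0U,    // no mask
  kCausal = 1U,  // causal (lower-triangular) masking
  kCustom = 2U,  // caller-supplied custom mask
};
```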
26 changes: 14 additions & 12 deletions csrc/batch_prefill_sm90.cu
@@ -29,16 +29,14 @@
namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

} // namespace flashinfer

@@ -110,7 +108,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
using DTypeO = DTypeQ;
using IdType = int32_t;

BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
using BatchPrefillRaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
BatchPrefillRaggedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
@@ -160,7 +159,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillRaggedParams>,
StandardAttention<BatchPrefillRaggedParams>>;
cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
@@ -220,7 +220,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
using DTypeO = DTypeQ;
using IdType = int32_t;

BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
using BatchPrefillPagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
BatchPrefillPagedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
@@ -272,7 +273,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillPagedParams>,
StandardAttention<BatchPrefillPagedParams>>;
cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
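With variants carrying `ParamsT`, the forward declarations above no longer thread DTypeQ/DTypeKV/DTypeO/IdType through every dispatcher: the params type is recovered from the variant itself, and `std::conditional_t` picks the variant inside each DISPATCH_BOOL branch. A self-contained toy of the mechanism (the Toy* names are invented; only the pattern mirrors the PR):

```cuda
#include <cuda_runtime.h>
#include <type_traits>

// Invented stand-ins; only the ParamsT mechanism mirrors the PR.
struct ToyParams {
  float logits_soft_cap;
};

template <typename Params>
struct ToyStandardAttention { using ParamsT = Params; };

template <typename Params>
struct ToyLogitsSoftCap { using ParamsT = Params; };

// One declaration covers every dtype/idtype combination, because the
// params type is deduced from the variant instead of extra template args.
template <unsigned HEAD_DIM, bool USE_SWA, typename AttentionVariant>
cudaError_t ToyDispatched(typename AttentionVariant::ParamsT& params,
                          cudaStream_t stream) {
  (void)params;
  (void)stream;
  return cudaSuccess;  // a real dispatcher would launch the kernel here
}

int main() {
  ToyParams params{8.0f};
  constexpr bool kUseSoftCap = true;  // DISPATCH_BOOL makes this constexpr
  using AttentionVariant =
      std::conditional_t<kUseSoftCap, ToyLogitsSoftCap<ToyParams>,
                         ToyStandardAttention<ToyParams>>;
  return ToyDispatched<128, /*USE_SWA=*/false, AttentionVariant>(
             params, /*stream=*/nullptr) == cudaSuccess ? 0 : 1;
}
```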
10 changes: 6 additions & 4 deletions csrc/single_prefill_sm90.cu
@@ -28,8 +28,8 @@
namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO>
cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>& params,
typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

} // namespace flashinfer
@@ -59,7 +59,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
using DTypeQ = cutlass_dtype_t<q_type>;
using DTypeKV = DTypeQ;
using DTypeO = DTypeQ;
SinglePrefillParams<DTypeQ, DTypeKV, DTypeO> params;
using SinglePrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
SinglePrefillParams params;
params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
@@ -96,7 +97,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<SinglePrefillParams>,
StandardAttention<SinglePrefillParams>>;
cudaError_t status =
SinglePrefillWithKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA, AttentionVariant>(
params, stream);
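One small idiom worth noting, used here and in batch_prefill_sm90.cu: `using SinglePrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;` declares a function-local alias that shadows the class template of the same name, so the rest of the function can use the bare name. Reduced to a toy:

```cuda
// The alias-shadows-template idiom from this PR, in isolation.
template <typename T>
struct Params {
  T value;
};

int main() {
  // The right-hand side still sees the outer template; afterwards, the
  // bare name "Params" means Params<int> inside this scope.
  using Params = Params<int>;
  Params p{42};
  return p.value == 42 ? 0 : 1;
}
```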