23 changes: 19 additions & 4 deletions csrc/fmhaReduction.cu
@@ -34,7 +34,7 @@ namespace kernels {
template <int32_t TileSizePerCtaQ, int32_t HeadDim, int32_t HeadDimPerCta, bool IsE4m3Bmm,
typename DtypeO, typename DtypePartialO>
__global__ void __launch_bounds__(NumThreadsPerCta, 2)
fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction,
fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction,
int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) {
// clang-format off
// The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta].
@@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)

// The number of validRows.
int32_t const numValidRows{TileSizePerCtaQ};
// The seqOffsetQ.
int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ
: params.ptrCumSeqLensQ[batchIdx]};
// The seqLenQ.
int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr
? params.mMaxSeqLenQ
: (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)};
// Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ.
if (ctaIdxQ >= seqLenQ) {
return;
}
// The actual seqLenKv.
int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]};
// Consider the causal-mask speculative decoding.
seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ);
// Consider sparseMlaTopK.
if (sparseMla) {
seqLenKv = min(seqLenKv, params.mSparseMlaTopK);
}
// The actual number of CtasKv (TileSizeKv is always 128 for now).
int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)};
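The index math above is easy to misread, so here is a minimal Python sketch of how the reduction kernel derives the effective KV length and KV-CTA count for one query token. The helper name is hypothetical; the arithmetic mirrors the kernel code above.

```python
def effective_num_ctas_kv(seq_len_kv: int, max_seq_len_q: int, cta_idx_q: int,
                          sparse_mla: bool, sparse_mla_top_k: int,
                          max_num_ctas_kv: int, tile_size_kv: int = 128) -> int:
    """Illustrative sketch of the seqLenKv/numCtasKv arithmetic in fmhaReductionKernel."""
    # Causal-mask speculative decoding: draft token cta_idx_q only attends to
    # KV tokens up to its own position, so earlier draft tokens see fewer keys.
    seq_len_kv = seq_len_kv - ((max_seq_len_q - 1) - cta_idx_q)
    # Sparse MLA attends to at most top-k selected KV tokens.
    if sparse_mla:
        seq_len_kv = min(seq_len_kv, sparse_mla_top_k)
    # Ceil-divide by the KV tile size (always 128 for now), capped by the launch-time maximum.
    return min((seq_len_kv + tile_size_kv - 1) // tile_size_kv, max_num_ctas_kv)

# Example: 3 draft tokens, 4096 cached KV tokens, top-k = 2048 -> 16 KV CTAs.
assert effective_num_ctas_kv(4096, 3, 2, True, 2048, max_num_ctas_kv=64) == 16
```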

@@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
config.numAttrs = 1;

// Select the kernel function pointer.
void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr;
void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr;
if (headDimPerCtaV == 128) {
SELECT_FMHA_REDUCTION_KERNEL(128);
} else if (headDimPerCtaV == 256) {
@@ -346,8 +361,8 @@
}

// Launch the kernel.
cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads,
numHeadDimCtasV);
cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction,
numCtasForAllHeads, numHeadDimCtasV);
cudaError_t err = cudaGetLastError();
FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
}
34 changes: 19 additions & 15 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -82,8 +82,8 @@ void trtllm_paged_attention_launcher(
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, cudaStream_t stream) {
int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -139,6 +139,12 @@ void trtllm_paged_attention_launcher(
runner_params.ptrAttentionSinks = attention_sinks;
runner_params.enable_pdl = enable_pdl;

// The sparse MLA parameters.
runner_params.mSparseMla = sparse_mla_top_k > 0;
runner_params.mSparseMlaTopK = sparse_mla_top_k;
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
<< "Only decode MLA supports sparse MLA";

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
if (mode == TllmPagedAttentionMode::Context) {
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
@@ -201,15 +207,13 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {

inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; }

void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks) {
void trtllm_paged_attention_decode(
TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -287,8 +291,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
stream);
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
enable_pdl, workspace_size, stream);
}

void trtllm_paged_attention_context(
@@ -367,8 +371,8 @@ void trtllm_paged_attention_context(
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
enable_pdl, workspace_size, stream);
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
/*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -87,7 +87,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988"
)
@@ -107,7 +107,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84"
"a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2"
)
TRTLLM_GEN_BMM: str = (
"85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf"
35 changes: 25 additions & 10 deletions flashinfer/decode.py
@@ -1922,6 +1922,7 @@ def _paged_run(
-1, # o_sf_vec_size
0, # o_sf_start_index
window_left,
0, # sparse_mla_top_k
self._sm_count,
enable_pdl,
workspace_size,
@@ -2328,6 +2329,7 @@ def trtllm_batch_decode_with_kv_cache(
o_sf_vec_size or -1,
o_sf_start_index,
window_left,
0, # sparse_mla_top_k
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2500,6 +2502,7 @@ def _check_trtllm_gen_mla_shape(
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
page_table,
page_size,
):
@@ -2524,16 +2527,23 @@
f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
)

B_block_table, block_num = page_table.shape
block_size = page_size
if B_q != B_block_table:
raise ValueError(
f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
)
if block_num % (128 / block_size) != 0:
raise ValueError(
f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
)
if sparse_mla_top_k > 0:
page_table_shape = page_table.shape
if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
raise ValueError(
f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
)
else:
B_block_table, block_num = page_table.shape
block_size = page_size
if B_q != B_block_table:
raise ValueError(
f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
)
if block_num % (128 / block_size) != 0:
raise ValueError(
f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
)
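As a quick illustration of the branch above, the page_table layouts accepted by the sparse and dense paths differ; the tensors below are placeholders, not values taken from the PR.

```python
import torch

B, Q_len, top_k, page_size = 2, 1, 2048, 32

# Sparse MLA: per-token top-k KV indices, shape [batch_size, q_len, sparse_mla_top_k].
sparse_page_table = torch.zeros(B, Q_len, top_k, dtype=torch.int32)

# Dense MLA: one row of page ids per sequence, shape [batch_size, num_pages],
# where num_pages must be a multiple of 128 / page_size.
num_pages = 8  # 8 % (128 / 32) == 0
dense_page_table = torch.zeros(B, num_pages, dtype=torch.int32)
```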


@flashinfer_api
@@ -2547,6 +2557,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
sparse_mla_top_k: int = 0,
out: Optional[torch.Tensor] = None,
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
@@ -2562,6 +2573,7 @@
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
sparse_mla_top_k: top-k for sparse MLA; set to 0 to disable sparse MLA. When > 0, it must be a multiple of 4 and block_tables holds per-token top-k indices.
block_tables: page_table of kv cache, [batch_size, num_pages] (or [batch_size, q_len, sparse_mla_top_k] when sparse_mla_top_k > 0)
seq_lens: kv sequence length of each request, [batch_size]
max_seq_len: max sequence length for kv_cache
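A rough usage sketch of the sparse decode path follows. The head-dim constants, the page-table shape, and the sparse_mla_top_k argument come from this diff; the import path, the keyword names for the query/KV-cache/workspace tensors, and the tensor layouts/dtypes are assumptions for illustration only.

```python
import torch
from flashinfer import trtllm_batch_decode_with_kv_cache_mla  # import path assumed

B, q_len, num_heads, top_k, page_size = 4, 1, 128, 2048, 32  # top_k must be a multiple of 4

q = torch.randn(B, q_len, num_heads, 576, device="cuda", dtype=torch.bfloat16)     # assumed layout
kv_cache = torch.randn(1024, page_size, 576, device="cuda", dtype=torch.bfloat16)  # assumed paged layout
workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
block_tables = torch.zeros(B, q_len, top_k, device="cuda", dtype=torch.int32)      # per-token top-k indices
seq_lens = torch.full((B,), 4096, device="cuda", dtype=torch.int32)

out = trtllm_batch_decode_with_kv_cache_mla(
    query=q, kv_cache=kv_cache, workspace_buffer=workspace,
    qk_nope_head_dim=128, kv_lora_rank=512, qk_rope_head_dim=64,
    block_tables=block_tables, seq_lens=seq_lens, max_seq_len=4096,
    sparse_mla_top_k=top_k,
)
```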
@@ -2654,6 +2666,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
block_tables,
block_size,
)
@@ -2687,6 +2700,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
-1, # o_sf_vec_size
0, # o_sf_start_index
-1, # window_left
sparse_mla_top_k,
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2768,6 +2782,7 @@ def xqa_batch_decode_with_kv_cache_mla(
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
0, # sparse_mla_top_k
block_tables,
block_size,
)
34 changes: 27 additions & 7 deletions include/flashinfer/trtllm/fmha/fmhaKernels.cuh
@@ -333,6 +333,10 @@ class TllmGenFmhaKernel {
if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
// The maximum attention window (the maximum number of tokensKv that will be attended to).
int maxAttentionWindow{params.mMaxSeqLenKv};
// The sparseMla only selects topK tokensKv.
if (params.mSparseMla) {
maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK);
}
// Some of the tilesKv will be skipped if the sliding window attention or chunked attention is
// used.
if (isSlidingOrChunkedCausalMask(selectKernelParams.mMaskType)) {
@@ -365,7 +369,8 @@
// Need to select a different kernel.
selectKernelParams.mSelectNewKernel = true;
} else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params) &&
selectKernelParams.mTileSizeKv == 128 && getEnvUseTileSizeKv64ForTrtllmGen()) {
!params.mSparseMla && selectKernelParams.mTileSizeKv == 128 &&
getEnvUseTileSizeKv64ForTrtllmGen()) {
// Use smaller tileSizeKv to fully utilize the SMs.
selectKernelParams.mTileSizeKv = 64;
// Need to select a different kernel.
@@ -461,13 +466,15 @@
// We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the
// following conditions are met:
// 1. The number of headsQPerKv is <= 32.
// 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
// 2. The number of headsQPerKv is < 128 for sparseMla.
// 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
// later) and
// the numCtas (after splitting the heads across multiple CTAs) <=
// params.mMultiProcessorCount.

// Check the conditions.
if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) {
if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) ||
useSwapsMmaAbMlaGenKernel(params)) {
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
} else {
// Otherwise, we use the high-throughput kernel.
@@ -476,6 +483,10 @@
if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel;
}
// The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128.
FLASHINFER_CHECK(
!params.mSparseMla || params.mNumHeadsQPerKv == 128,
"The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128");
// The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128.
if (params.mNumHeadsQPerKv == 128) {
selectKernelParams.mUses2CtaMma = true;
@@ -524,8 +535,16 @@
"Sliding window attention and chunked attention should not be used together");
selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
}
// NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

// The number of tokens per page.
int numTokensPerPage = params.mNumTokensPerPage;
// SparseMla kernels use a fixed numTokensPerPage = 1.
if (params.mSparseMla) {
numTokensPerPage = 1;
} else if (!isPagedKv(params.mQkvLayout)) {
// NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
numTokensPerPage = 0;
}

// Debug info.
std::string info =
@@ -542,7 +561,8 @@
", numTokensPerPage=" + std::to_string(numTokensPerPage) +
", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) +
", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) +
", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma);
", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) +
", sparseMla=" + std::to_string(params.mSparseMla);
IKL_LOG_DEBUG(
"Searching for kernel traits (%d available) in TllmGenFmhaKernel(%s, %s, %s, %d) %s",
getNumLoadedKernels(), toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM,
@@ -555,7 +575,7 @@
selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV,
selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta,
selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma,
/* sparseMla */ false),
params.mSparseMla),
info);
}

4 changes: 4 additions & 0 deletions include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
@@ -287,6 +287,10 @@ struct TllmGenFmhaRunnerParams {
float mScaleSfKv;
// The SF scale for output.
float mScaleSfO;
// Whether to use sparse MLA.
bool mSparseMla;
// The top k value for sparse MLA.
int mSparseMlaTopK;
// The cuda stream.
cudaStream_t stream;
// Whether to enable PDL (Programmatic Dependent Launch).
19 changes: 17 additions & 2 deletions include/flashinfer/trtllm/fmha/kernelParams.h
@@ -486,8 +486,8 @@ struct KernelParams {

// Check shape must be in range [1, 2^32]
int32_t dim = shapes.size();
// Max five dimension and min 3 dimension.
FLASHINFER_CHECK((dim <= 5) && (dim >= 3));
// Max five dimension and min 2 dimension.
FLASHINFER_CHECK((dim <= 5) && (dim >= 2));
// Check shape range.
for (int32_t ii = 0; ii < dim; ++ii) {
FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1))); // Size must be min 1
@@ -597,6 +597,16 @@ struct KernelParams {
std::vector<uint32_t> tileShapeKv(shapeK.size(), 1);
tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor;
tileShapeKv[1] = numKeysPerTile;

// If sparse MLA is enabled, the shape and stride for K need to be updated for 2D layout
// (numTokensKvInPagedKv, headDimQk).
if (options.mSparseMla) {
shapeK = std::vector<uint64_t>{static_cast<uint64_t>(options.mHeadDimQk),
static_cast<uint64_t>(INT_MAX)};
strideK = std::vector<uint64_t>{1, static_cast<uint64_t>(options.mHeadDimQk)};
tileShapeKv[1] = 1;
}

// Build tma descriptor for K.
params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK,
tileShapeKv, const_cast<void*>(kPtr),
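To make the 2D K layout above concrete: with numTokensPerPage forced to 1, shapeK = (headDimQk, numTokensKv) and strideK = (1, headDimQk), so each top-k index addresses one headDimQk-wide row of the flat K pool. The sketch below is illustrative only and mirrors that addressing, not the TMA descriptor itself.

```python
import torch

head_dim_qk = 576            # MLA decode: 512 (kv_lora_rank) + 64 (qk_rope_head_dim)
num_tokens_in_pool = 4096    # illustrative size of the flat KV pool

# KV token t occupies k_pool[t * head_dim_qk : (t + 1) * head_dim_qk].
k_pool = torch.arange(num_tokens_in_pool * head_dim_qk, dtype=torch.float32)

def gather_k(token_indices):
    """Gather the K rows selected by the per-token top-k indices (one token per 'page')."""
    rows = [k_pool[t * head_dim_qk:(t + 1) * head_dim_qk] for t in token_indices]
    return torch.stack(rows)

top_k_indices = [7, 42, 1023, 5]  # sparseMlaTopK must be a multiple of 4 (16B cpAsync on the indices)
assert gather_k(top_k_indices).shape == (4, head_dim_qk)
```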
@@ -720,6 +730,11 @@ struct KernelParams {
params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
params.mScaleSfKv = options.mScaleSfKv;
params.ptrSoftmaxStats = options.softmaxStatsPtr;
// The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the
// indices.
FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
"SparseMlaTopK must be a multiple of 4");
params.mSparseMlaTopK = options.mSparseMlaTopK;
// TODO: Integrate trtllm block-sparse attention kernels when needed.
params.mUseBlockSparseAttention = false;
return params;