Commit ba867e6

expose trtllm-gen per-tensor sparse MLA kernels

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
1 parent b8e6c83 commit ba867e6

7 files changed: +566 additions, -38 deletions

csrc/fmhaReduction.cu

Lines changed: 19 additions & 4 deletions

@@ -34,7 +34,7 @@ namespace kernels {
 template <int32_t TileSizePerCtaQ, int32_t HeadDim, int32_t HeadDimPerCta, bool IsE4m3Bmm,
           typename DtypeO, typename DtypePartialO>
 __global__ void __launch_bounds__(NumThreadsPerCta, 2)
-    fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction,
+    fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction,
                         int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) {
   // clang-format off
   // The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta].
@@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)

   // The number of validRows.
   int32_t const numValidRows{TileSizePerCtaQ};
+  // The seqOffsetQ.
+  int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ
+                                                            : params.ptrCumSeqLensQ[batchIdx]};
+  // The seqLenQ.
+  int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr
+                            ? params.mMaxSeqLenQ
+                            : (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)};
+  // Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ.
+  if (ctaIdxQ >= seqLenQ) {
+    return;
+  }
   // The actual number of seqLenKv.
   int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]};
   // Consider the causal-mask speculative decoding.
   seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ);
+  // Consider sparseMlaTopK.
+  if (sparseMla) {
+    seqLenKv = min(seqLenKv, params.mSparseMlaTopK);
+  }
   // The actual number of CtasKv (TileSizeKv is always 128 for now).
   int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)};

@@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   config.numAttrs = 1;

   // Select the kernel function pointer.
-  void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr;
+  void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr;
   if (headDimPerCtaV == 128) {
     SELECT_FMHA_REDUCTION_KERNEL(128);
   } else if (headDimPerCtaV == 256) {
@@ -346,8 +361,8 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   }

   // Launch the kernel.
-  cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads,
-                     numHeadDimCtasV);
+  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction,
+                     numCtasForAllHeads, numHeadDimCtasV);
   cudaError_t err = cudaGetLastError();
   FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
 }
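
Note: the per-CTA arithmetic in the reduction hunk above boils down to the sketch below. This is an illustrative Python restatement, not code from this commit; it assumes the fixed TileSizeKv of 128 stated in the kernel comment, and the names mirror the kernel variables.

def effective_num_ctas_kv(seq_len_kv: int, cta_idx_q: int, max_seq_len_q: int,
                          sparse_mla_top_k: int, max_num_ctas_kv: int) -> int:
    # Causal-mask speculative decoding: later draft tokens attend to fewer tokensKv.
    seq_len_kv = seq_len_kv - ((max_seq_len_q - 1) - cta_idx_q)
    # Sparse MLA only reduces over the selected topK tokensKv.
    if sparse_mla_top_k > 0:
        seq_len_kv = min(seq_len_kv, sparse_mla_top_k)
    # TileSizeKv is always 128 for now: ceil-divide and clamp to the maximum.
    return min((seq_len_kv + 127) // 128, max_num_ctas_kv)

# Example: topK = 2048 caps a 32768-token KV sequence at 16 reduction CTAs.
assert effective_num_ctas_kv(32768, 0, 1, 2048, 64) == 16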

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 19 additions & 15 deletions

@@ -82,8 +82,8 @@ void trtllm_paged_attention_launcher(
     int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
     double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
     const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
-    int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
-    int64_t workspace_size, cudaStream_t stream) {
+    int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count,
+    bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -139,6 +139,12 @@ void trtllm_paged_attention_launcher(
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;

+  // The sparse MLA parameters.
+  runner_params.mSparseMla = sparse_mla_top_k > 0;
+  runner_params.mSparseMlaTopK = sparse_mla_top_k;
+  TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
+      << "Only decode MLA supports sparse MLA";
+
   AlignedAllocator float_allocator(workspace_buffer, workspace_size);
   if (mode == TllmPagedAttentionMode::Context) {
     runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
@@ -201,15 +207,13 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {

 inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; }

-void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
-                                   TensorView query, TensorView key_cache, TensorView value_cache,
-                                   TensorView workspace_buffer, TensorView block_tables,
-                                   TensorView seq_lens, int64_t max_kv_len,
-                                   Variant<double, ffi::Tensor> bmm1_scale,
-                                   Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
-                                   int64_t o_sf_vec_size, int64_t o_sf_start_index,
-                                   int64_t window_left, int64_t sm_count, bool enable_pdl,
-                                   int64_t workspace_size, Optional<TensorView> attention_sinks) {
+void trtllm_paged_attention_decode(
+    TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
+    TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
+    TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
+    Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
+    int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -287,8 +291,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
       bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
-      o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
-      stream);
+      o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
+      enable_pdl, workspace_size, stream);
 }

 void trtllm_paged_attention_context(
@@ -367,8 +371,8 @@ void trtllm_paged_attention_context(
       max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
       head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
       max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
-      bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
-      enable_pdl, workspace_size, stream);
+      bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
+      /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
 }

 void trtllm_ragged_attention_launcher(
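
Note: the launcher-side guard added above can be restated in Python as the small sketch below (illustrative only, not part of the commit): a positive sparse_mla_top_k enables the sparse path, and it is only accepted for the decode MLA head dimensions.

def validate_sparse_mla(head_dim_qk: int, head_dim_vo: int, sparse_mla_top_k: int) -> bool:
    # Mirrors the TVM_FFI_ICHECK above: sparse MLA implies head_dim_qk == 576 and head_dim_vo == 512.
    sparse_mla = sparse_mla_top_k > 0
    if sparse_mla and not (head_dim_qk == 576 and head_dim_vo == 512):
        raise ValueError("Only decode MLA supports sparse MLA")
    return sparse_mla

assert validate_sparse_mla(576, 512, 2048) is True   # sparse decode MLA
assert validate_sparse_mla(128, 128, 0) is False     # dense path, sparse disabled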

flashinfer/decode.py

Lines changed: 25 additions & 10 deletions

@@ -1922,6 +1922,7 @@ def _paged_run(
             -1,  # o_sf_vec_size
             0,  # o_sf_start_index
             window_left,
+            0,  # sparse_mla_top_k
             self._sm_count,
             enable_pdl,
             workspace_size,
@@ -2328,6 +2329,7 @@ def trtllm_batch_decode_with_kv_cache(
         o_sf_vec_size or -1,
         o_sf_start_index,
         window_left,
+        0,  # sparse_mla_top_k
         sm_count,
         enable_pdl,
         workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2500,6 +2502,7 @@ def _check_trtllm_gen_mla_shape(
     qk_nope_head_dim,
     kv_lora_rank,
     qk_rope_head_dim,
+    sparse_mla_top_k,
     page_table,
     page_size,
 ):
@@ -2524,16 +2527,23 @@
             f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
         )

-    B_block_table, block_num = page_table.shape
-    block_size = page_size
-    if B_q != B_block_table:
-        raise ValueError(
-            f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
-        )
-    if block_num % (128 / block_size) != 0:
-        raise ValueError(
-            f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
-        )
+    if sparse_mla_top_k > 0:
+        page_table_shape = page_table.shape
+        if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
+            raise ValueError(
+                f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
+            )
+    else:
+        B_block_table, block_num = page_table.shape
+        block_size = page_size
+        if B_q != B_block_table:
+            raise ValueError(
+                f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
+            )
+        if block_num % (128 / block_size) != 0:
+            raise ValueError(
+                f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
+            )


 @flashinfer_api
@@ -2547,6 +2557,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_seq_len: int,
+    sparse_mla_top_k: int = 0,
     out: Optional[torch.Tensor] = None,
     bmm1_scale: Union[float, torch.Tensor] = 1.0,
     bmm2_scale: Union[float, torch.Tensor] = 1.0,
@@ -2562,6 +2573,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         qk_nope_head_dim: qk_nope_head_dim, must be 128
         kv_lora_rank: kv_lora_rank, must be 512
         qk_rope_head_dim: qk_rope_head_dim, must be 64
+        sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
         block_tables: page_table of kv cache, [batch_size, num_pages]
         seq_lens: query_len
         max_seq_len: max sequence length for kv_cache
@@ -2654,6 +2666,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
+        sparse_mla_top_k,
         block_tables,
         block_size,
     )
@@ -2687,6 +2700,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         -1,  # o_sf_vec_size
         0,  # o_sf_start_index
         -1,  # window_left
+        sparse_mla_top_k,
         sm_count,
         enable_pdl,
         workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2768,6 +2782,7 @@ def xqa_batch_decode_with_kv_cache_mla(
         qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
+        0,  # sparse_mla_top_k
         block_tables,
         block_size,
     )
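
Note: together with the kernelParams.h check further down (sparseMlaTopK must be a multiple of 4), the new shape check above implies the input contract sketched below for the sparse path. The helper is a hypothetical, self-contained illustration and is not code from this commit.

import torch

def make_sparse_mla_page_table(batch: int, q_len: int, top_k: int,
                               num_kv_tokens: int) -> torch.Tensor:
    # topK must be a multiple of 4 (the kernel loads the indices with 16B cpAsync).
    assert top_k % 4 == 0, "sparse_mla_top_k must be a multiple of 4"
    # With sparse_mla_top_k > 0 the page table becomes a per-token index tensor of
    # shape (batch, q_len, top_k); sparse MLA kernels use numTokensPerPage = 1, so
    # each entry addresses a single tokenKv. int32 dtype is assumed here.
    return torch.randint(0, num_kv_tokens, (batch, q_len, top_k), dtype=torch.int32)

page_table = make_sparse_mla_page_table(batch=4, q_len=1, top_k=2048, num_kv_tokens=32768)
assert page_table.shape == (4, 1, 2048)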

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 27 additions & 7 deletions

@@ -333,6 +333,10 @@ class TllmGenFmhaKernel {
     if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
       // The maximum attention window (the maximum number of tokensKv that will be attended to).
       int maxAttentionWindow{params.mMaxSeqLenKv};
+      // The sparseMla only selects topK tokensKv.
+      if (params.mSparseMla) {
+        maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK);
+      }
       // Some of the tilesKv will be skipped if the sliding window attention or chunked attention is
       // used.
       if (isSlidingOrChunkedCausalMask(selectKernelParams.mMaskType)) {
@@ -365,7 +369,8 @@
         // Need to select a different kernel.
         selectKernelParams.mSelectNewKernel = true;
       } else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params) &&
-                 selectKernelParams.mTileSizeKv == 128 && getEnvUseTileSizeKv64ForTrtllmGen()) {
+                 !params.mSparseMla && selectKernelParams.mTileSizeKv == 128 &&
+                 getEnvUseTileSizeKv64ForTrtllmGen()) {
         // Use smaller tileSizeKv to fully utilize the SMs.
         selectKernelParams.mTileSizeKv = 64;
         // Need to select a different kernel.
@@ -461,13 +466,15 @@
     // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the
     // following conditions are met:
     // 1. The number of headsQPerKv is <= 32.
-    // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
+    // 2. The number of headsQPerKv is < 128 for sparseMla.
+    // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
     // later) and
     //    the numCtas (after splitting the heads across multiple CTAs) <=
     //    params.mMultiProcessorCount.

     // Check the conditions.
-    if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) {
+    if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) ||
+        useSwapsMmaAbMlaGenKernel(params)) {
       kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
     } else {
       // Otherwise, we use the high-throughput kernel.
@@ -476,6 +483,10 @@
       if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
         selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel;
       }
+      // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128.
+      FLASHINFER_CHECK(
+          !params.mSparseMla || params.mNumHeadsQPerKv == 128,
+          "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128");
       // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128.
       if (params.mNumHeadsQPerKv == 128) {
         selectKernelParams.mUses2CtaMma = true;
@@ -524,8 +535,16 @@
                        "Sliding window attention and chunked attention should not be used together");
       selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
     }
-    // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
-    int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;
+
+    // The number of tokens per page.
+    int numTokensPerPage = params.mNumTokensPerPage;
+    // SparseMla kernels use a fixed numTokensPerPage = 1.
+    if (params.mSparseMla) {
+      numTokensPerPage = 1;
+    } else if (!isPagedKv(params.mQkvLayout)) {
+      // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
+      numTokensPerPage = 0;
+    }

     // Debug info.
     std::string info =
@@ -542,7 +561,8 @@
         ", numTokensPerPage=" + std::to_string(numTokensPerPage) +
         ", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) +
         ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) +
-        ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma);
+        ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) +
+        ", sparseMla=" + std::to_string(params.mSparseMla);
     IKL_LOG_DEBUG(
         "Searching for kernel traits (%d available) in TllmGenFmhaKernel(%s, %s, %s, %d) %s",
         getNumLoadedKernels(), toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM,
@@ -555,7 +575,7 @@
             selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV,
             selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta,
             selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma,
-            /* sparseMla */ false),
+            params.mSparseMla),
         info);
   }
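
Note: the effect of the sparse MLA flag on generation-kernel selection above can be summarized with the simplified Python sketch below. It is illustrative only; the benchmark-driven useSwapsMmaAbMlaGenKernel condition and the 2-CTA MMA bookkeeping are omitted, and the return strings stand in for the FmhaKernelType enum.

def select_mla_gen_kernel(num_heads_q_per_kv: int, sparse_mla: bool) -> str:
    # Low-latency path: few headsQ per KV head, or sparse MLA with fewer than 128 heads.
    if num_heads_q_per_kv <= 32 or (sparse_mla and num_heads_q_per_kv < 128):
        return "SwapsMmaAbForGeneration"
    # High-throughput path: the sparse MLA variant only supports numHeadsQPerKv = 128.
    if sparse_mla and num_heads_q_per_kv != 128:
        raise ValueError(
            "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128"
        )
    return "KeepsMmaAbForGeneration"

assert select_mla_gen_kernel(64, sparse_mla=True) == "SwapsMmaAbForGeneration"
assert select_mla_gen_kernel(128, sparse_mla=True) == "KeepsMmaAbForGeneration"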

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Lines changed: 4 additions & 0 deletions

@@ -287,6 +287,10 @@ struct TllmGenFmhaRunnerParams {
   float mScaleSfKv;
   // The SF scale for output.
   float mScaleSfO;
+  // Whether to use sparse MLA.
+  bool mSparseMla;
+  // The top k value for sparse MLA.
+  int mSparseMlaTopK;
   // The cuda stream.
   cudaStream_t stream;
   // Whether to enable PDL (Programmatic Dependent Launch).

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 17 additions & 2 deletions

@@ -486,8 +486,8 @@ struct KernelParams {

     // Check shape must be in range [1, 2^32]
     int32_t dim = shapes.size();
-    // Max five dimension and min 3 dimension.
-    FLASHINFER_CHECK((dim <= 5) && (dim >= 3));
+    // Max five dimension and min 2 dimension.
+    FLASHINFER_CHECK((dim <= 5) && (dim >= 2));
     // Check shape range.
     for (int32_t ii = 0; ii < dim; ++ii) {
       FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1)));  // Size must be min 1
@@ -597,6 +597,16 @@ struct KernelParams {
     std::vector<uint32_t> tileShapeKv(shapeK.size(), 1);
     tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor;
     tileShapeKv[1] = numKeysPerTile;
+
+    // If sparse MLA is enabled, the shape and stride for K need to be updated for 2D layout
+    // (numTokensKvInPagedKv, headDimQk).
+    if (options.mSparseMla) {
+      shapeK = std::vector<uint64_t>{static_cast<uint64_t>(options.mHeadDimQk),
+                                     static_cast<uint64_t>(INT_MAX)};
+      strideK = std::vector<uint64_t>{1, static_cast<uint64_t>(options.mHeadDimQk)};
+      tileShapeKv[1] = 1;
+    }
+
     // Build tma descriptor for K.
     params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK,
                                         tileShapeKv, const_cast<void*>(kPtr),
@@ -720,6 +730,11 @@ struct KernelParams {
     params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
     params.mScaleSfKv = options.mScaleSfKv;
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
+    // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the
+    // indices.
+    FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
+                     "SparseMlaTopK must be a multiple of 4");
+    params.mSparseMlaTopK = options.mSparseMlaTopK;
     // TODO: Integrate trtllm block-sparse attention kernels when needed.
     params.mUseBlockSparseAttention = false;
     return params;
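
Note: the 2D K layout installed above for sparse MLA can be pictured with the following illustrative sketch (not part of the commit). K is addressed as (numTokensKvInPagedKv, headDimQk) with strideK = [1, headDimQk], so with numTokensPerPage = 1 the topK indices pick out individual token rows of the latent-plus-rope key pool.

def sparse_mla_k_elem_offset(token_idx: int, dim_idx: int, head_dim_qk: int = 576) -> int:
    # Element offset implied by shapeK = [headDimQk, numTokensKv], strideK = [1, headDimQk]:
    # the head dimension is contiguous and consecutive token rows are headDimQk elements apart.
    return token_idx * head_dim_qk + dim_idx

# Token 10, element 5 of its 576-wide (512 latent + 64 rope) key vector.
assert sparse_mla_k_elem_offset(10, 5) == 10 * 576 + 5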
