
Commit 934f093

expose trtllm-gen per-tensor sparse MLA kernels
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
1 parent 54eb341 commit 934f093

File tree: 7 files changed, +562 -31 lines changed


csrc/fmhaReduction.cu

Lines changed: 19 additions & 4 deletions
@@ -34,7 +34,7 @@ namespace kernels {
 template <int32_t TileSizePerCtaQ, int32_t HeadDim, int32_t HeadDimPerCta, bool IsE4m3Bmm,
           typename DtypeO, typename DtypePartialO>
 __global__ void __launch_bounds__(NumThreadsPerCta, 2)
-    fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction,
+    fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction,
                         int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) {
   // clang-format off
   // The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta].
@@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
 
   // The number of validRows.
   int32_t const numValidRows{TileSizePerCtaQ};
+  // The seqOffsetQ.
+  int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ
+                                                            : params.ptrCumSeqLensQ[batchIdx]};
+  // The seqLenQ.
+  int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr
+                            ? params.mMaxSeqLenQ
+                            : (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)};
+  // Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ.
+  if (ctaIdxQ >= seqLenQ) {
+    return;
+  }
   // The actual number of seqLenKv.
   int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]};
   // Consider the causal-mask speculative decoding.
   seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ);
+  // Consider sparseMlaTopK.
+  if (sparseMla) {
+    seqLenKv = min(seqLenKv, params.mSparseMlaTopK);
+  }
   // The actual number of CtasKv (TileSizeKv is always 128 for now).
   int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)};
 
@@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   config.numAttrs = 1;
 
   // Select the kernel function pointer.
-  void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr;
+  void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr;
   if (headDimPerCtaV == 128) {
     SELECT_FMHA_REDUCTION_KERNEL(128);
   } else if (headDimPerCtaV == 256) {
@@ -346,8 +361,8 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   }
 
   // Launch the kernel.
-  cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads,
-                     numHeadDimCtasV);
+  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction,
+                     numCtasForAllHeads, numHeadDimCtasV);
   cudaError_t err = cudaGetLastError();
   FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
 }
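For reference, a minimal Python sketch of the sequence-length arithmetic this hunk adds to the reduction kernel: seqLenKv is clamped to the sparse top-k before the number of KV CTAs is derived from the 128-token KV tile. The standalone function below is illustrative only; the real logic lives inside fmhaReductionKernel.

def num_kv_ctas_for_reduction(seq_len_kv, cta_idx_q, max_seq_len_q,
                              max_num_ctas_kv, sparse_mla, sparse_mla_top_k):
    """Illustrative mirror of the seqLenKv / numCtasKv math in fmhaReductionKernel."""
    # Causal-mask speculative decoding: earlier query tokens see fewer KV tokens.
    seq_len_kv = seq_len_kv - ((max_seq_len_q - 1) - cta_idx_q)
    # Sparse MLA attends to at most top-k selected KV tokens.
    if sparse_mla:
        seq_len_kv = min(seq_len_kv, sparse_mla_top_k)
    # Ceil-divide by the KV tile size (always 128 for now), capped by mMaxNumCtasKv.
    return min((seq_len_kv + 127) // 128, max_num_ctas_kv)

# Example: 4096 cached tokens but top-k = 2048 -> only ceil(2048 / 128) = 16 KV CTAs to reduce.
assert num_kv_ctas_for_reduction(4096, 0, 1, 64, True, 2048) == 16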

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 14 additions & 8 deletions
@@ -79,8 +79,8 @@ void trtllm_paged_attention_launcher(
     int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
     int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
     double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
-    int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
+    int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k,
+    int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -132,6 +132,12 @@ void trtllm_paged_attention_launcher(
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;
 
+  // The sparse MLA parameters.
+  runner_params.mSparseMla = sparse_mla_top_k > 0;
+  runner_params.mSparseMlaTopK = sparse_mla_top_k;
+  TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
+      << "Only decode MLA supports sparse MLA";
+
   AlignedAllocator float_allocator(workspace_buffer, workspace_size);
   if (mode == TllmPagedAttentionMode::Context) {
     runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
@@ -199,9 +205,9 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
                                    TensorView workspace_buffer, TensorView block_tables,
                                    TensorView seq_lens, int64_t max_kv_len, double bmm1_scale,
                                    double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
-                                   int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
-                                   bool enable_pdl, int64_t workspace_size,
-                                   Optional<TensorView> attention_sinks) {
+                                   int64_t o_sf_start_index, int64_t window_left,
+                                   int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
+                                   int64_t workspace_size, Optional<TensorView> attention_sinks) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -260,8 +266,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
       TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
       kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
-      bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
-      enable_pdl, workspace_size, stream);
+      bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
+      sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
 }
 
 void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_scale_factor,
@@ -322,7 +328,7 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
       max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
       head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
       max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
-      window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
+      window_left, sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
 }
 
 void trtllm_ragged_attention_launcher(
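The launcher-side contract added above is small enough to restate as a sketch: a positive sparse_mla_top_k is only accepted for the decode-MLA head geometry (head_dim_qk 576, head_dim_vo 512), and the context entry point always forwards /*sparse_mla_top_k=*/0. The helper name below is hypothetical and only illustrates the check.

def check_sparse_mla_args(head_dim_qk, head_dim_vo, sparse_mla_top_k):
    """Illustrative mirror of the TVM_FFI_ICHECK added to trtllm_paged_attention_launcher."""
    if sparse_mla_top_k > 0 and not (head_dim_qk == 576 and head_dim_vo == 512):
        raise ValueError("Only decode MLA supports sparse MLA")

check_sparse_mla_args(576, 512, 2048)   # decode MLA geometry: accepted
check_sparse_mla_args(128, 128, 0)      # non-MLA geometry must keep top-k at 0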

flashinfer/decode.py

Lines changed: 26 additions & 10 deletions
@@ -1916,6 +1916,7 @@ def _paged_run(
             -1,  # o_sf_vec_size
             0,  # o_sf_start_index
             window_left,
+            0,  # sparse_mla_top_k
             self._sm_count,
             enable_pdl,
             workspace_size,
@@ -2315,6 +2316,7 @@ def trtllm_batch_decode_with_kv_cache(
         o_sf_vec_size or -1,
         o_sf_start_index,
         window_left,
+        0,  # sparse_mla_top_k
         sm_count,
         enable_pdl,
         workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2486,6 +2488,7 @@ def _check_trtllm_gen_mla_shape(
     qk_nope_head_dim,
     kv_lora_rank,
     qk_rope_head_dim,
+    sparse_mla_top_k,
     page_table,
     page_size,
 ):
@@ -2510,16 +2513,23 @@
             f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
         )
 
-    B_block_table, block_num = page_table.shape
-    block_size = page_size
-    if B_q != B_block_table:
-        raise ValueError(
-            f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
-        )
-    if block_num % (128 / block_size) != 0:
-        raise ValueError(
-            f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
-        )
+    if sparse_mla_top_k > 0:
+        page_table_shape = page_table.shape
+        if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
+            raise ValueError(
+                f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
+            )
+    else:
+        B_block_table, block_num = page_table.shape
+        block_size = page_size
+        if B_q != B_block_table:
+            raise ValueError(
+                f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
+            )
+        if block_num % (128 / block_size) != 0:
+            raise ValueError(
+                f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
+            )
 
 
 def trtllm_batch_decode_with_kv_cache_mla(
@@ -2532,6 +2542,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_seq_len: int,
+    sparse_mla_top_k: int = 0,
     out: Optional[torch.Tensor] = None,
     bmm1_scale: Optional[float] = 1.0,
     bmm2_scale: Optional[float] = 1.0,
@@ -2549,6 +2560,7 @@
         qk_nope_head_dim: qk_nope_head_dim, must be 128
         kv_lora_rank: kv_lora_rank, must be 512
         qk_rope_head_dim: qk_rope_head_dim, must be 64
+        sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
         block_tables: page_table of kv cache, [batch_size, num_pages]
         seq_lens: query_len
         max_seq_len: max sequence length for kv_cache
@@ -2636,6 +2648,7 @@
         qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
+        sparse_mla_top_k,
         block_tables,
         block_size,
     )
@@ -2663,6 +2676,7 @@
             "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
         )
 
+    print(f"query shape: {query.shape}")
     run_func(
         out,
         None,  # fp4 output not supported in wrapper api yet.
@@ -2679,6 +2693,7 @@
         -1,  # o_sf_vec_size
         0,  # o_sf_start_index
         -1,  # window_left
+        sparse_mla_top_k,
         sm_count,
         enable_pdl,
         workspace_buffer.numel() * workspace_buffer.element_size(),
@@ -2766,6 +2781,7 @@ def xqa_batch_decode_with_kv_cache_mla(
         qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
+        0,  # sparse_mla_top_k
         block_tables,
         block_size,
     )
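With sparse_mla_top_k now threaded through the wrapper, a sparse decode call could look roughly like the sketch below. Keyword names follow the docstring and the shape check above (for sparse MLA the page table carries per-token top-k indices of shape [batch_size, q_len, top_k]); the kv_cache layout, dtypes, sizes, and workspace size are assumptions made for illustration only, not values taken from the commit.

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

batch_size, q_len, top_k = 4, 1, 2048
num_q_heads, kv_lora_rank, qk_rope_head_dim = 128, 512, 64
head_dim_ckv = kv_lora_rank + qk_rope_head_dim  # 576
num_kv_tokens, max_seq_len = 16384, 8192

# Query over the compressed KV space (tensor rank assumed for this sketch).
query = torch.randn(batch_size, q_len, num_q_heads, head_dim_ckv,
                    dtype=torch.bfloat16, device="cuda")
# Paged compressed KV cache; layout assumed here.
kv_cache = torch.randn(num_kv_tokens, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
# Sparse MLA page table: per-token top-k KV indices, [batch_size, q_len, top_k];
# top_k must be a multiple of 4 (see kernelParams.h below).
block_tables = torch.randint(0, num_kv_tokens, (batch_size, q_len, top_k),
                             dtype=torch.int32, device="cuda")
seq_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device="cuda")
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

out = trtllm_batch_decode_with_kv_cache_mla(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace_buffer,
    qk_nope_head_dim=128,
    kv_lora_rank=kv_lora_rank,
    qk_rope_head_dim=qk_rope_head_dim,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=max_seq_len,
    sparse_mla_top_k=top_k,  # 0 (the default) keeps the dense paged-KV path
)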

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 27 additions & 7 deletions
@@ -333,6 +333,10 @@ class TllmGenFmhaKernel {
     if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
       // The maximum attention window (the maximum number of tokensKv that will be attended to).
       int maxAttentionWindow{params.mMaxSeqLenKv};
+      // The sparseMla only selects topK tokensKv.
+      if (params.mSparseMla) {
+        maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK);
+      }
       // Some of the tilesKv will be skipped if the sliding window attention or chunked attention is
       // used.
       if (isSlidingOrChunkedCausalMask(selectKernelParams.mMaskType)) {
@@ -365,7 +369,8 @@
       // Need to select a different kernel.
       selectKernelParams.mSelectNewKernel = true;
     } else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params) &&
-               selectKernelParams.mTileSizeKv == 128 && getEnvUseTileSizeKv64ForTrtllmGen()) {
+               !params.mSparseMla && selectKernelParams.mTileSizeKv == 128 &&
+               getEnvUseTileSizeKv64ForTrtllmGen()) {
       // Use smaller tileSizeKv to fully utilize the SMs.
       selectKernelParams.mTileSizeKv = 64;
       // Need to select a different kernel.
@@ -461,13 +466,15 @@
     // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the
     // following conditions are met:
     // 1. The number of headsQPerKv is <= 32.
-    // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
+    // 2. The number of headsQPerKv is < 128 for sparseMla.
+    // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned
     //    later) and
     //    the numCtas (after splitting the heads across multiple CTAs) <=
     //    params.mMultiProcessorCount.
 
     // Check the conditions.
-    if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) {
+    if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) ||
+        useSwapsMmaAbMlaGenKernel(params)) {
       kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
     } else {
       // Otherwise, we use the high-throughput kernel.
@@ -476,6 +483,10 @@
       if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
         selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel;
       }
+      // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128.
+      FLASHINFER_CHECK(
+          !params.mSparseMla || params.mNumHeadsQPerKv == 128,
+          "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128");
       // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128.
       if (params.mNumHeadsQPerKv == 128) {
         selectKernelParams.mUses2CtaMma = true;
@@ -524,8 +535,16 @@
                        "Sliding window attention and chunked attention should not be used together");
       selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
     }
-    // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
-    int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;
+
+    // The number of tokens per page.
+    int numTokensPerPage = params.mNumTokensPerPage;
+    // SparseMla kernels use a fixed numTokensPerPage = 1.
+    if (params.mSparseMla) {
+      numTokensPerPage = 1;
+    } else if (!isPagedKv(params.mQkvLayout)) {
+      // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
+      numTokensPerPage = 0;
+    }
 
     // Debug info.
     std::string info =
@@ -542,7 +561,8 @@
         ", numTokensPerPage=" + std::to_string(numTokensPerPage) +
         ", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) +
         ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) +
-        ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma);
+        ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) +
+        ", sparseMla=" + std::to_string(params.mSparseMla);
     IKL_LOG_DEBUG(
         "Searching for kernel traits (%d available) in TllmGenFmhaKernel(%s, %s, %s, %d) %s",
         getNumLoadedKernels(), toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM,
@@ -555,7 +575,7 @@
             selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV,
             selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta,
             selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma,
-            /* sparseMla */ false),
+            params.mSparseMla),
         info);
   }
 
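Taken together, the selection changes above amount to a small decision procedure for the sparse path; the Python sketch below restates it for readability (names are paraphrases of the C++ fields and kernel types, not an actual API).

def select_sparse_mla_kernel(num_heads_q_per_kv, max_seq_len_kv, top_k):
    """Illustrative restatement of the sparseMla branches in the kernel-selection logic."""
    # Sparse MLA only attends to top-k tokensKv, so the window that sizes the
    # multi-CTA-KV split is clamped to top_k.
    max_attention_window = min(max_seq_len_kv, top_k)
    # Fewer than 128 headsQPerKv -> low-latency SwapsMmaAbForGeneration kernel;
    # exactly 128 -> 2-CTA keepsMmaAbForGeneration kernel; anything else is rejected.
    if num_heads_q_per_kv < 128:
        kernel_type = "SwapsMmaAbForGeneration"
    elif num_heads_q_per_kv == 128:
        kernel_type = "KeepsMmaAbForGeneration (2-CTA MMA)"
    else:
        raise ValueError("sparseMla keepsMmaAbForGeneration kernels require numHeadsQPerKv == 128")
    # The tileSizeKv = 64 fallback is skipped for sparse MLA, and the paged KV cache
    # is addressed one token per page.
    return {
        "maxAttentionWindow": max_attention_window,
        "kernelType": kernel_type,
        "tileSizeKv": 128,
        "numTokensPerPage": 1,
    }

print(select_sparse_mla_kernel(num_heads_q_per_kv=128, max_seq_len_kv=8192, top_k=2048))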
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Lines changed: 4 additions & 0 deletions
@@ -287,6 +287,10 @@ struct TllmGenFmhaRunnerParams {
   float mScaleSfKv;
   // The SF scale for output.
   float mScaleSfO;
+  // Whether to use sparse MLA.
+  bool mSparseMla;
+  // The top k value for sparse MLA.
+  int mSparseMlaTopK;
   // The cuda stream.
   cudaStream_t stream;
   // Whether to enable PDL (Programmatic Dependent Launch).

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 17 additions & 2 deletions
@@ -486,8 +486,8 @@ struct KernelParams {
 
     // Check shape must be in range [1, 2^32]
     int32_t dim = shapes.size();
-    // Max five dimension and min 3 dimension.
-    FLASHINFER_CHECK((dim <= 5) && (dim >= 3));
+    // Max five dimension and min 2 dimension.
+    FLASHINFER_CHECK((dim <= 5) && (dim >= 2));
     // Check shape range.
     for (int32_t ii = 0; ii < dim; ++ii) {
       FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1)));  // Size must be min 1
@@ -597,6 +597,16 @@
     std::vector<uint32_t> tileShapeKv(shapeK.size(), 1);
     tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor;
     tileShapeKv[1] = numKeysPerTile;
+
+    // If sparse MLA is enabled, the shape and stride for K need to be updated for 2D layout
+    // (numTokensKvInPagedKv, headDimQk).
+    if (options.mSparseMla) {
+      shapeK = std::vector<uint64_t>{static_cast<uint64_t>(options.mHeadDimQk),
+                                     static_cast<uint64_t>(INT_MAX)};
+      strideK = std::vector<uint64_t>{1, static_cast<uint64_t>(options.mHeadDimQk)};
+      tileShapeKv[1] = 1;
+    }
+
     // Build tma descriptor for K.
     params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK,
                                         tileShapeKv, const_cast<void*>(kPtr),
@@ -721,6 +731,11 @@
     params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
     params.mScaleSfKv = options.mScaleSfKv;
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
+    // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the
+    // indices.
+    FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
+                     "SparseMlaTopK must be a multiple of 4");
+    params.mSparseMlaTopK = options.mSparseMlaTopK;
     // TODO: Integrate trtllm block-sparse attention kernels when needed.
     params.mUseBlockSparseAttention = false;
     return params;