diff --git a/csrc/fmhaReduction.cu b/csrc/fmhaReduction.cu index 1f1ca8c755..e329e1c14b 100644 --- a/csrc/fmhaReduction.cu +++ b/csrc/fmhaReduction.cu @@ -34,7 +34,7 @@ namespace kernels { template __global__ void __launch_bounds__(NumThreadsPerCta, 2) - fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction, + fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction, int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) { // clang-format off // The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta]. @@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2) // The number of validRows. int32_t const numValidRows{TileSizePerCtaQ}; + // The seqOffsetQ. + int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ + : params.ptrCumSeqLensQ[batchIdx]}; + // The seqLenQ. + int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr + ? params.mMaxSeqLenQ + : (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)}; + // Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ. + if (ctaIdxQ >= seqLenQ) { + return; + } // The actual number of seqLenKv. int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]}; // Consider the causal-mask speculative decoding. seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ); + // Consider sparseMlaTopK. + if (sparseMla) { + seqLenKv = min(seqLenKv, params.mSparseMlaTopK); + } // The actual number of CtasKv (TileSizeKv is always 128 for now). int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)}; @@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams config.numAttrs = 1; // Select the kernel function pointer. - void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr; + void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr; if (headDimPerCtaV == 128) { SELECT_FMHA_REDUCTION_KERNEL(128); } else if (headDimPerCtaV == 256) { @@ -346,8 +361,8 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams } // Launch the kernel. - cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads, - numHeadDimCtasV); + cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction, + numCtasForAllHeads, numHeadDimCtasV); cudaError_t err = cudaGetLastError(); FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); } diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 5c1de17bb0..89fe53b874 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -82,8 +82,8 @@ void trtllm_paged_attention_launcher( int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, cudaStream_t stream) { + int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -139,6 +139,12 @@ void trtllm_paged_attention_launcher( runner_params.ptrAttentionSinks = attention_sinks; runner_params.enable_pdl = enable_pdl; + // The sparse MLA parameters. + runner_params.mSparseMla = sparse_mla_top_k > 0; + runner_params.mSparseMlaTopK = sparse_mla_top_k; + TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0) + << "Only decode MLA supports sparse MLA"; + AlignedAllocator float_allocator(workspace_buffer, workspace_size); if (mode == TllmPagedAttentionMode::Context) { runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; @@ -201,15 +207,13 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) { inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; } -void trtllm_paged_attention_decode(TensorView out, Optional out_scale_factor, - TensorView query, TensorView key_cache, TensorView value_cache, - TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_kv_len, - Variant bmm1_scale, - Variant bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, Optional attention_sinks) { +void trtllm_paged_attention_decode( + TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, + TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_kv_len, Variant bmm1_scale, + Variant bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, + int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -287,8 +291,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, - o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, - stream); + o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count, + enable_pdl, workspace_size, stream); } void trtllm_paged_attention_context( @@ -367,8 +371,8 @@ void trtllm_paged_attention_context( max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, - bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, - enable_pdl, workspace_size, stream); + bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, + /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index cfb2862e47..b520023b70 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" ) @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84" + "a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2" ) TRTLLM_GEN_BMM: str = ( "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 1f682c9844..3f9f03ebb7 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1922,6 +1922,7 @@ def _paged_run( -1, # o_sf_vec_size 0, # o_sf_start_index window_left, + 0, # sparse_mla_top_k self._sm_count, enable_pdl, workspace_size, @@ -2328,6 +2329,7 @@ def trtllm_batch_decode_with_kv_cache( o_sf_vec_size or -1, o_sf_start_index, window_left, + 0, # sparse_mla_top_k sm_count, enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), @@ -2500,6 +2502,7 @@ def _check_trtllm_gen_mla_shape( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + sparse_mla_top_k, page_table, page_size, ): @@ -2524,16 +2527,23 @@ def _check_trtllm_gen_mla_shape( f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" ) - B_block_table, block_num = page_table.shape - block_size = page_size - if B_q != B_block_table: - raise ValueError( - f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" - ) - if block_num % (128 / block_size) != 0: - raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" - ) + if sparse_mla_top_k > 0: + page_table_shape = page_table.shape + if page_table_shape != (B_q, Q_len, sparse_mla_top_k): + raise ValueError( + f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}" + ) + else: + B_block_table, block_num = page_table.shape + block_size = page_size + if B_q != B_block_table: + raise ValueError( + f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" + ) + if block_num % (128 / block_size) != 0: + raise ValueError( + f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" + ) @flashinfer_api @@ -2547,6 +2557,7 @@ def trtllm_batch_decode_with_kv_cache_mla( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_seq_len: int, + sparse_mla_top_k: int = 0, out: Optional[torch.Tensor] = None, bmm1_scale: Union[float, torch.Tensor] = 1.0, bmm2_scale: Union[float, torch.Tensor] = 1.0, @@ -2562,6 +2573,7 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_nope_head_dim: qk_nope_head_dim, must be 128 kv_lora_rank: kv_lora_rank, must be 512 qk_rope_head_dim: qk_rope_head_dim, must be 64 + sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. block_tables: page_table of kv cache, [batch_size, num_pages] seq_lens: query_len max_seq_len: max sequence length for kv_cache @@ -2654,6 +2666,7 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + sparse_mla_top_k, block_tables, block_size, ) @@ -2687,6 +2700,7 @@ def trtllm_batch_decode_with_kv_cache_mla( -1, # o_sf_vec_size 0, # o_sf_start_index -1, # window_left + sparse_mla_top_k, sm_count, enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), @@ -2768,6 +2782,7 @@ def xqa_batch_decode_with_kv_cache_mla( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + 0, # sparse_mla_top_k block_tables, block_size, ) diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 5bd91f4064..7fb695ed6d 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -333,6 +333,10 @@ class TllmGenFmhaKernel { if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { // The maximum attention window (the maximum number of tokensKv that will be attended to). int maxAttentionWindow{params.mMaxSeqLenKv}; + // The sparseMla only selects topK tokensKv. + if (params.mSparseMla) { + maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK); + } // Some of the tilesKv will be skipped if the sliding window attention or chunked attention is // used. if (isSlidingOrChunkedCausalMask(selectKernelParams.mMaskType)) { @@ -365,7 +369,8 @@ class TllmGenFmhaKernel { // Need to select a different kernel. selectKernelParams.mSelectNewKernel = true; } else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params) && - selectKernelParams.mTileSizeKv == 128 && getEnvUseTileSizeKv64ForTrtllmGen()) { + !params.mSparseMla && selectKernelParams.mTileSizeKv == 128 && + getEnvUseTileSizeKv64ForTrtllmGen()) { // Use smaller tileSizeKv to fully utilize the SMs. selectKernelParams.mTileSizeKv = 64; // Need to select a different kernel. @@ -461,13 +466,15 @@ class TllmGenFmhaKernel { // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the // following conditions are met: // 1. The number of headsQPerKv is <= 32. - // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned + // 2. The number of headsQPerKv is < 128 for sparseMla. + // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned // later) and // the numCtas (after splitting the heads across multiple CTAs) <= // params.mMultiProcessorCount. // Check the conditions. - if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) { + if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || + useSwapsMmaAbMlaGenKernel(params)) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; } else { // Otherwise, we use the high-throughput kernel. @@ -476,6 +483,10 @@ class TllmGenFmhaKernel { if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } + // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. + FLASHINFER_CHECK( + !params.mSparseMla || params.mNumHeadsQPerKv == 128, + "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128"); // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. if (params.mNumHeadsQPerKv == 128) { selectKernelParams.mUses2CtaMma = true; @@ -524,8 +535,16 @@ class TllmGenFmhaKernel { "Sliding window attention and chunked attention should not be used together"); selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal; } - // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. - int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage; + + // The number of tokens per page. + int numTokensPerPage = params.mNumTokensPerPage; + // SparseMla kernels use a fixed numTokensPerPage = 1. + if (params.mSparseMla) { + numTokensPerPage = 1; + } else if (!isPagedKv(params.mQkvLayout)) { + // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. + numTokensPerPage = 0; + } // Debug info. std::string info = @@ -542,7 +561,8 @@ class TllmGenFmhaKernel { ", numTokensPerPage=" + std::to_string(numTokensPerPage) + ", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) + ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) + - ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma); + ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) + + ", sparseMla=" + std::to_string(params.mSparseMla); IKL_LOG_DEBUG( "Searching for kernel traits (%d available) in TllmGenFmhaKernel(%s, %s, %s, %d) %s", getNumLoadedKernels(), toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, @@ -555,7 +575,7 @@ class TllmGenFmhaKernel { selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV, selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta, selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma, - /* sparseMla */ false), + params.mSparseMla), info); } diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index b05ce51ae3..ab48bc04cd 100755 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -287,6 +287,10 @@ struct TllmGenFmhaRunnerParams { float mScaleSfKv; // The SF scale for output. float mScaleSfO; + // Whether to use sparse MLA. + bool mSparseMla; + // The top k value for sparse MLA. + int mSparseMlaTopK; // The cuda stream. cudaStream_t stream; // Whether to enable PDL (Programmatic Dependent Launch). diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index c184ad9e10..a308eacfce 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -486,8 +486,8 @@ struct KernelParams { // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); - // Max five dimension and min 3 dimension. - FLASHINFER_CHECK((dim <= 5) && (dim >= 3)); + // Max five dimension and min 2 dimension. + FLASHINFER_CHECK((dim <= 5) && (dim >= 2)); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1))); // Size must be min 1 @@ -597,6 +597,16 @@ struct KernelParams { std::vector tileShapeKv(shapeK.size(), 1); tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor; tileShapeKv[1] = numKeysPerTile; + + // If sparse MLA is enabled, the shape and stride for K need to be updated for 2D layout + // (numTokensKvInPagedKv, headDimQk). + if (options.mSparseMla) { + shapeK = std::vector{static_cast(options.mHeadDimQk), + static_cast(INT_MAX)}; + strideK = std::vector{1, static_cast(options.mHeadDimQk)}; + tileShapeKv[1] = 1; + } + // Build tma descriptor for K. params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK, tileShapeKv, const_cast(kPtr), @@ -720,6 +730,11 @@ struct KernelParams { params.mStartTokenIdxSfO = options.mSfStartTokenIdx; params.mScaleSfKv = options.mScaleSfKv; params.ptrSoftmaxStats = options.softmaxStatsPtr; + // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the + // indices. + FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0, + "SparseMlaTopK must be a multiple of 4"); + params.mSparseMlaTopK = options.mSparseMlaTopK; // TODO: Integrate trtllm block-sparse attention kernels when needed. params.mUseBlockSparseAttention = false; return params; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 4e2b7aefe0..dd0002ff06 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -414,13 +414,13 @@ def _test_trtllm_batch_prefill( max_q_len, max_kv_len, device_scale, + head_dim, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") # Set up test parameters torch.manual_seed(0) - head_dim = 128 # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size @@ -639,6 +639,7 @@ def _test_trtllm_batch_prefill( @pytest.mark.parametrize("enable_sink", [True, False]) @pytest.mark.parametrize("max_q_len", [511]) @pytest.mark.parametrize("max_kv_len", [2047]) +@pytest.mark.parametrize("head_dim", [128, 256]) def test_trtllm_batch_prefill( kv_layout, batch_size, @@ -653,6 +654,7 @@ def test_trtllm_batch_prefill( enable_sink, max_q_len, max_kv_len, + head_dim, ): _test_trtllm_batch_prefill( kv_layout, @@ -669,6 +671,7 @@ def test_trtllm_batch_prefill( max_q_len, max_kv_len, kv_dtype == "fp8", + head_dim, ) @@ -690,6 +693,7 @@ def test_trtllm_batch_prefill( @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_q_len", [8192]) @pytest.mark.parametrize("max_kv_len", [8192]) +@pytest.mark.parametrize("head_dim", [128, 256]) def test_trtllm_batch_prefill_bs1( kv_layout, batch_size, @@ -704,6 +708,7 @@ def test_trtllm_batch_prefill_bs1( enable_sink, max_q_len, max_kv_len, + head_dim, ): _test_trtllm_batch_prefill( kv_layout, @@ -720,6 +725,7 @@ def test_trtllm_batch_prefill_bs1( max_q_len, max_kv_len, False, + head_dim, ) @@ -1202,7 +1208,6 @@ def test_trtllm_batch_decode_head_dim_256( device_scale, ): # Small number of test cases for head_dim = 256 - pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") _test_trtllm_batch_decode( "trtllm-gen", kv_layout, diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index d56be03eb6..d71e8cb386 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,5 +1,6 @@ import pytest import torch +import random import flashinfer from flashinfer.utils import get_compute_capability @@ -9,6 +10,205 @@ workspace_size = 128 * 1024 * 1024 +def generate_sparse_indices( + batch_size: int, + q_len_per_request: int, + seq_lens: torch.Tensor, + topk: int, + page_size: int, + block_tables: torch.Tensor, + device: str, + seed: int = 42, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate sparse attention indices for MLA. + + Returns: + abs_indices: [batch_size, q_len_per_request, topk] - absolute positions in sequence + indices_in_kvcache: [batch_size, q_len_per_request, topk] - positions in blocked KV cache + """ + random.seed(seed) + torch.manual_seed(seed) + + block_tables_cpu = block_tables.cpu() + seq_lens_cpu = seq_lens.cpu() + + abs_indices = torch.empty( + batch_size, q_len_per_request, topk, dtype=torch.int32, device="cpu" + ) + indices_in_kvcache = torch.empty( + batch_size, q_len_per_request, topk, dtype=torch.int32, device="cpu" + ) + + for i in range(batch_size): + cur_seq_len = int(seq_lens_cpu[i].item()) + # Generate indices for each query position + for j in range(q_len_per_request): + # Randomly sample topk positions from the sequence + if cur_seq_len > 0: + # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk] + cur_abs_indices = torch.arange(0, topk, device="cpu") + # Convert to blocked indices + cur_blocked_indices = block_tables_cpu[ + i, cur_abs_indices // page_size + ] * page_size + (cur_abs_indices % page_size) + else: + cur_abs_indices = torch.empty(0, dtype=torch.int32, device="cpu") + cur_blocked_indices = torch.empty(0, dtype=torch.int32, device="cpu") + + # Pad with -1 if we don't have enough indices + if len(cur_abs_indices) < topk: + pad_len = topk - len(cur_abs_indices) + cur_abs_indices = torch.cat( + [ + cur_abs_indices, + torch.full((pad_len,), -1, device="cpu", dtype=torch.int32), + ] + ) + cur_blocked_indices = torch.cat( + [ + cur_blocked_indices, + torch.full((pad_len,), -1, device="cpu", dtype=torch.int32), + ] + ) + + # Randomly permute the indices + # perm = torch.randperm(topk, device="cpu") + perm = torch.arange(0, topk, device="cpu") + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + return abs_indices.to(device), indices_in_kvcache.to(device) + + +def sparse_mla_reference_torch( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] + q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, d] + blocked_v: torch.Tensor, # [?, block_size, dv] + page_size: int, + is_causal: bool, + sm_scale: float, + indices: torch.Tensor | None = None, # [batch_size, s_q, topk] +) -> tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch for MLA attention. + Based on FlashMLA's reference implementation. + + Args: + cache_seqlens: Sequence lengths for each batch [batch_size] + block_table: Block table mapping [batch_size, max_num_blocks] + q: Query tensor [batch_size, s_q, h_q, d] + blocked_k: Blocked key cache [num_blocks, block_size, d] + blocked_v: Blocked value cache [num_blocks, block_size, dv] + page_size: Size of each block/page + is_causal: Whether to apply causal masking + sm_scale: Softmax scale factor + indices: Optional sparse indices [batch_size, s_q, topk] + + Returns: + output: Attention output [batch_size, s_q, h_q, dv] + lse: Log-sum-exp values [batch_size, h_q, s_q] + """ + + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + """Create attention mask for top-k sparse attention.""" + mask = torch.zeros(s_q, s_k, dtype=torch.bool) + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + key: torch.Tensor, # [s_k, d] + value: torch.Tensor, # [s_k, dv] + is_causal: bool, + sm_scale: float, + indices: torch.Tensor | None, # [s_q, topk] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot-product attention.""" + h_q = query.size(0) + s_q = query.shape[-2] + s_k = key.shape[-2] + dv = value.shape[-1] + + query = query.float() + key = key.float() + value = value.float() + + # Handle NaN values in KV + key[key != key] = 0.0 + value[value != value] = 0.0 + + # Compute attention weights: [h_q, s_q, s_k] + attn_weight = query @ key.transpose(-2, -1) + + # Apply masking if needed + if (is_causal and query.size(1) > 1) or indices is not None: + mask = torch.ones(s_q, s_k, dtype=torch.bool) + if is_causal: + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device) + mask = mask.to(device=query.device) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(query.dtype) + + # Scale and softmax + attn_weight *= sm_scale + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + + # Compute output + output = attn_weight @ value # [h_q, s_q, dv] + + # Correct for query tokens which have no attendable keys + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + dv = blocked_v.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) + + for i in range(b): + cur_len = int(cache_seqlens_cpu[i].item()) + cur_num_blocks = (cur_len + page_size - 1) // page_size + cur_block_indices = block_table[i][0:cur_num_blocks] + + # Gather KV for this sequence + cur_key = blocked_k[cur_block_indices].view(-1, d)[:cur_len, ...] + cur_value = blocked_v[cur_block_indices].view(-1, dv)[:cur_len, ...] + + cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), # [h_q, s_q, d] + cur_key, # [s_k, d] + cur_value, # [s_k, dv] + is_causal, + sm_scale, + indices[i] if indices is not None else None, + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + + out_ref = out_ref.to(torch.bfloat16).to(q.device) + return out_ref, lse_ref + + def trtllm_batch_decode_mla( batch_size: int, scale: float, @@ -296,3 +496,258 @@ def test_dsr1_trtllm_mla( backend, MAX_SEQ_LEN, ) + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 16, 32, 64, 128], +) +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("q_len_per_request", [1, 2]) +@pytest.mark.parametrize("topk", [128, 2048]) +@pytest.mark.parametrize("is_varlen", [False, True]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("backend", ["trtllm-gen"]) +def test_trtllm_batch_decode_mla_sparse( + batch_size: int, + scale: float, + dtype: torch.dtype, + q_len_per_request: int, + topk: int, + is_varlen: bool, + enable_pdl: bool, + backend: str, +): + """ + Test sparse MLA decoding with top-k attention. + Based on FlashMLA test patterns from: + https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_decoding.py + """ + compute_capability = get_compute_capability(torch.device(device="cuda")) + if backend == "trtllm-gen": + if compute_capability[0] != 10: + pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") + + torch.manual_seed(42) + device = "cuda:0" + + # Deepseek attention config (decode-MLA) + num_q_heads = 128 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + + # Fixed or variable sequence lengths + if is_varlen: + # Variable sequence lengths + MAX_SEQ_LEN = 4096 + seq_lens = [ + max( + topk, + int( + torch.distributions.Normal(MAX_SEQ_LEN, MAX_SEQ_LEN / 2) + .sample() + .item() + ), + ) + for _ in range(batch_size) + ] + seq_lens[-1] = MAX_SEQ_LEN # Ensure at least one max length + seq_lens = [min(s, MAX_SEQ_LEN) for s in seq_lens] + else: + # Fixed sequence length + MAX_SEQ_LEN = 4096 + seq_lens = [MAX_SEQ_LEN] * batch_size + + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + + # Initialize query tensors + query = torch.randn( + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + device=device, + ) + query.clamp_(min=-1.0, max=1.0) + query = query.to(dtype) + + # Calculate blocks needed + page_size = 32 + blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size + max_num_blocks_per_seq = blocks_per_seq.max().item() + total_blocks_needed = int(blocks_per_seq.sum().item()) + + # Generate random but unique block IDs + all_block_ids = torch.randperm(total_blocks_needed, device=device) + + # Create block tables + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + ) + block_id = 0 + for i in range(batch_size): + num_blocks_needed = int(blocks_per_seq[i].item()) + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] + block_id += num_blocks_needed + + # Create KV cache + num_blocks = total_blocks_needed + kv_cache = torch.randn( + size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), + device=device, + ) + kv_cache.clamp_(min=-1.0, max=1.0) + kv_cache = kv_cache.to(dtype) + + # Generate sparse indices + abs_indices, indices_in_kvcache = generate_sparse_indices( + batch_size, + q_len_per_request, + seq_lens_tensor, + topk, + page_size, + block_tables, + device, + ) + + # Mask unused KV cache entries with NaN for correctness checking + kv_cache_ref = kv_cache.clone() + if dtype == torch.float8_e4m3fn: + kv_cache_ref = kv_cache_ref.to(torch.bfloat16) + + # Mark all positions as NaN initially + all_indices = indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + + # Only used indices should be valid + kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim) + used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu") + used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True + kv_cache_flat[~used_mask] = float("0") + + # Allocate workspace buffers + global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + workspace_size, dtype=torch.int8, device=device + ) + if global_trtllm_gen_fmha_workspace_buffer is None: + global_trtllm_gen_fmha_workspace_buffer = torch.zeros( + workspace_size, dtype=torch.int8, device=device + ) + workspace_buffer = global_trtllm_gen_fmha_workspace_buffer + # workspace_buffer_ref = global_workspace_buffer + + # Run sparse decode-MLA + query_input = query.clone() + output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query_input, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=indices_in_kvcache, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + sparse_mla_top_k=topk, + bmm1_scale=scale / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5), + bmm2_scale=1.0, + enable_pdl=enable_pdl, + backend=backend, + ) + + # Check workspace buffer is zeroed + assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + + # For now, just check that output has correct shape and no NaNs + expected_shape = (batch_size, q_len_per_request, num_q_heads, kv_lora_rank) + assert output.shape == expected_shape, ( + f"Output shape {output.shape} != {expected_shape}" + ) + + # Check for NaNs + if dtype != torch.float8_e4m3fn: + assert not torch.isnan(output).any(), "Output contains NaN values" + + # Generate reference output using PyTorch implementation + query_ref = query.clone() + if dtype == torch.float8_e4m3fn: + query_ref = query_ref.to(torch.bfloat16) + + # Split kv_cache into K and V components + # K uses full dimension (kv_lora_rank + qk_rope_head_dim) + # V uses only kv_lora_rank dimension + blocked_k = kv_cache_ref # [num_blocks, page_size, kv_lora_rank + qk_rope_head_dim] + blocked_v = kv_cache_ref[ + ..., :kv_lora_rank + ] # [num_blocks, page_size, kv_lora_rank] + + sm_scale = scale / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) + + out_ref, lse_ref = sparse_mla_reference_torch( + cache_seqlens=seq_lens_tensor, + block_table=block_tables, + q=query_ref, + blocked_k=blocked_k, + blocked_v=blocked_v, + page_size=page_size, + is_causal=True, # Cover cases where number of attendable kv values are less than topk + sm_scale=sm_scale, + indices=abs_indices, + ) + + # Compare outputs + assert not torch.isnan(output).any(), "Kernel output contains NaN values" + assert not torch.isnan(out_ref).any(), "Reference output contains NaN values" + + if dtype == torch.float8_e4m3fn: + # FP8 has lower precision, use more relaxed tolerances + try: + torch.testing.assert_close( + output.float(), + out_ref.float(), + rtol=1e-1, + atol=1e-1, + ) + except AssertionError as e: + # Calculate element-wise differences for debugging + diff = torch.abs(output.float() - out_ref.float()) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + print(f"Max difference: {max_diff}, Mean difference: {mean_diff}") + print(f"Output sample: {output[0, 0, 0, :8]}") + print(f"Reference sample: {out_ref[0, 0, 0, :8]}") + raise e + else: + # BF16 should have better precision + try: + torch.testing.assert_close( + output.float(), + out_ref.float(), + rtol=2e-2, + atol=8e-4, + ) + except AssertionError as e: + # Calculate element-wise differences for debugging + diff = torch.abs(output.float() - out_ref.float()) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + print(f"Max difference: {max_diff}, Mean difference: {mean_diff}") + print(f"Output sample: {output[0, 0, 0, :8]}") + print(f"Output sample: {output[0, 1, 0, :8]}") + print(f"Reference sample: {out_ref[0, 0, 0, :8]}") + print(f"Reference sample: {out_ref[0, 1, 0, :8]}") + raise e + + print( + f"Sparse MLA test passed: batch_size={batch_size}, topk={topk}, " + f"q_len={q_len_per_request}, varlen={is_varlen}, dtype={dtype}" + ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py index 60933cf89b..772ceead0b 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py @@ -250,7 +250,7 @@ def test_mnnvl_allreduce_custom_communicator( available_gpus = torch.cuda.device_count() if world_size > available_gpus: - raise ValueError( + pytest.skip( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}")