From 0ae5f551e791498099a3248f2cef6a8064a218cd Mon Sep 17 00:00:00 2001
From: Avery Yingyi Huang
Date: Fri, 5 Sep 2025 15:13:25 -0400
Subject: [PATCH 1/3] init

---
 tests/test_trtllm_gen_attention.py | 43 ++++++++++++++++++++++--------
 tests/test_trtllm_gen_mla.py       | 15 +++++++----
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py
index f42d418d04..81d4bf63cd 100755
--- a/tests/test_trtllm_gen_attention.py
+++ b/tests/test_trtllm_gen_attention.py
@@ -17,7 +17,8 @@
 
 GPU_DEVICE = "cuda:0"
 
-global_workspace_buffer = None
+global_workspace_buffer = None  # can be empty-initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero-initialized
 workspace_size = 128 * 1024 * 1024
 
 
@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill(
         else None
     )
 
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    workspace_buffer_ref = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
 
     # Run reference wrapper
     wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer_ref, kv_layout
     )
     plan_params = {
         "qo_indptr": q_indptr,
@@ -505,16 +511,21 @@ def test_trtllm_batch_decode(
         else None
     )
 
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     # Run reference wrapper
     wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True
+        workspace_buffer_ref, kv_layout, use_tensor_cores=True
    )
     plan_params = {
         "indptr": kv_indptr,
@@ -699,7 +710,17 @@
     # Initialize scale
     scale = float(1.0 / (head_dim_qk**0.5))
 
-    workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device)
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     qo_indptr = torch.cat(
         [
@@ -722,7 +743,7 @@
     ).int()
 
     wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-        torch.zeros(workspace_size, device="cuda", dtype=torch.uint8),
+        workspace_buffer_ref,
         kv_layout="NHD",
         backend="cutlass",
     )
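The hunks above all follow the same allocation pattern: the reference wrapper's workspace is pure scratch and may be allocated with torch.empty, while the trtllm-gen workspace carries cross-CTA counters and must be allocated with torch.zeros. A minimal standalone sketch of that pattern, assuming only torch (the helper name get_workspace_buffers and the module-level names are illustrative, not part of the patch); the same pattern is applied to the MLA test below.

import torch

WORKSPACE_SIZE = 128 * 1024 * 1024  # 128 MiB, matching workspace_size in the tests

_workspace_ref = None     # scratch only: contents are never read before being written
_workspace_trtllm = None  # holds cross-CTA counters: must start out zeroed

def get_workspace_buffers(device: str = "cuda:0"):
    """Lazily allocate both global workspace buffers once per process."""
    global _workspace_ref, _workspace_trtllm
    if _workspace_ref is None:
        # torch.empty skips the fill kernel; fine for pure scratch memory.
        _workspace_ref = torch.empty(WORKSPACE_SIZE, dtype=torch.int8, device=device)
    if _workspace_trtllm is None:
        # torch.zeros is required here: the kernels assume the counter region starts at 0.
        _workspace_trtllm = torch.zeros(WORKSPACE_SIZE, dtype=torch.int8, device=device)
    return _workspace_ref, _workspace_trtllm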
diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py
index e73da75337..55297861e0 100644
--- a/tests/test_trtllm_gen_mla.py
+++ b/tests/test_trtllm_gen_mla.py
@@ -5,7 +5,8 @@
 
 import flashinfer
 
-global_workspace_buffer = None
+global_workspace_buffer = None  # can be empty-initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero-initialized
 workspace_size = 128 * 1024 * 1024
 
 
@@ -96,12 +97,17 @@ def test_trtllm_batch_decode_mla(
 
     # Allocate workspace buffer
     # todo(Yingyi): calculate the actual size of workspace buffer
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
             workspace_size, dtype=torch.int8, device=device
         )
-    workspace_buffer = global_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     bmm1_log2_scale_tensor = (
         torch.tensor(
@@ -140,7 +146,6 @@ def test_trtllm_batch_decode_mla(
     sm_scale = scale / (
         (128 + 64) ** 0.5
     )  # use head dimension before matrix absorption
-    workspace_buffer_ref = torch.empty(workspace_size, dtype=torch.int8, device=device)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
         workspace_buffer_ref,
         backend="fa2",
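Patch 2 below adds the same regression check after every trtllm-gen call: the kernels are expected to restore their counter region to zero on exit, which is what makes the one-time torch.zeros initialization in Patch 1 sufficient. A hedged sketch of that check as a reusable helper, assuming the counter-region size quoted in the patch comments (assert_counter_workspace_zeroed is an illustrative name, not a function in the tests, which inline the assertion):

import torch

# Size of the counter region at the head of the trtllm-gen workspace, taken
# from the note(Yingyi) comments in the patches; it may change in the future.
COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4

def assert_counter_workspace_zeroed(workspace_buffer: torch.Tensor) -> None:
    """Fail if a kernel left nonzero values in the counter region."""
    head = workspace_buffer[:COUNTER_WORKSPACE_BYTES]
    assert (head.cpu().numpy() == 0).all(), "counter workspace was not restored to zero"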
From 3d16e141583e94187dac4ad260f85fc27ef09ea2 Mon Sep 17 00:00:00 2001
From: Avery Yingyi Huang
Date: Fri, 5 Sep 2025 17:14:04 -0400
Subject: [PATCH 2/3] upd

---
 tests/test_trtllm_gen_attention.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py
index 81d4bf63cd..7cf172465d 100755
--- a/tests/test_trtllm_gen_attention.py
+++ b/tests/test_trtllm_gen_attention.py
@@ -378,6 +378,9 @@ def test_trtllm_batch_prefill(
         o_sf_vec_size=o_sf_vec_size,
         enable_pdl=enable_pdl,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -420,6 +423,9 @@ def test_trtllm_batch_prefill(
         torch.testing.assert_close(
             output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1
         )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
@@ -546,7 +552,7 @@ def test_trtllm_batch_decode(
     if q_len_per_req > 1:
         # hide the output_ref from decode wrapper for speculative decoding test
         wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffer, kv_layout
+            workspace_buffer_ref, kv_layout
         )
         plan_params_prefill = {
             "qo_indptr": q_indptr,
@@ -587,6 +593,9 @@ def test_trtllm_batch_decode(
         enable_pdl=enable_pdl,
         q_len_per_req=q_len_per_req,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -659,6 +668,9 @@ def test_trtllm_batch_decode(
         atol=1e-1,
         max_mismatched_elements=5,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 @pytest.mark.parametrize("batch_size", [4, 128, 256])
@@ -796,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek(
         atol=1e-3,
         rtol=1e-3,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 if __name__ == "__main__":

From cde21f7283fb50c77ae5b289c98577a47374dcd8 Mon Sep 17 00:00:00 2001
From: Avery Yingyi Huang
Date: Fri, 5 Sep 2025 17:14:45 -0400
Subject: [PATCH 3/3] upd

---
 tests/test_trtllm_gen_mla.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py
index 55297861e0..8070877d30 100644
--- a/tests/test_trtllm_gen_mla.py
+++ b/tests/test_trtllm_gen_mla.py
@@ -141,6 +141,9 @@ def test_trtllm_batch_decode_mla(
         bmm2_scale_tensor=bmm2_scale_tensor,
         enable_pdl=enable_pdl,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     # Run reference attention and align output
     sm_scale = scale / (
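All six call sites across the two test files inline the same assertion, which copies the 8 MiB counter region (8192 * 256 * 4 bytes) to the host and compares it through NumPy. If that overhead matters, an equivalent device-side check (a sketch, not part of these patches; the function name is illustrative) keeps the comparison on the GPU and moves only a single scalar back to the host:

import torch

def counter_workspace_is_zero(
    workspace_buffer: torch.Tensor, n_bytes: int = 8192 * 256 * 4
) -> bool:
    # count_nonzero runs on the device; only the scalar result crosses to the host.
    return torch.count_nonzero(workspace_buffer[:n_bytes]).item() == 0

Usage would mirror the inlined asserts: assert counter_workspace_is_zero(workspace_buffer) after each run call.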