From 034bb83dfbfb5845c1886e9810769e864fc2c9d7 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 6 Oct 2025 18:15:39 +0000 Subject: [PATCH 1/3] Test cases for trtllm attention errors --- tests/attention/test_trtllm_gen_attention.py | 48 ++++++-------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 790eb02d86..99f2c593e3 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -529,41 +529,18 @@ def test_trtllm_batch_prefill( @pytest.mark.parametrize( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", [ - (4, 1, 16, 2, 1), - (4, 1, 32, 2, 5), - (4, 2, 64, 2, 5), - (4, 3, 32, 2, 5), - (4, 3, 64, 2, 1), - (4, 4, 64, 4, 1), - (4, 5, 64, 4, 8), - (128, 1, 64, 2, 5), - (128, 2, 32, 4, 1), - (128, 3, 16, 4, 8), - (128, 4, 16, 2, 5), - (128, 5, 16, 2, 5), - (256, 1, 64, 4, 8), - (256, 2, 16, 2, 8), - (256, 3, 64, 4, 5), - (256, 4, 32, 2, 8), - (256, 5, 32, 2, 1), + (1, 1, 16, 8, 8), ], ) -@pytest.mark.parametrize("window_left", [-1, 127]) +@pytest.mark.parametrize("window_left", [-1]) @pytest.mark.parametrize( "q_dtype,kv_dtype,o_dtype", [ - ("bf16", "bf16", "bf16"), - ("fp16", "fp16", "fp16"), - ("bf16", "fp8", "bf16"), - ("fp16", "fp8", "fp16"), - ("fp8", "fp8", "bf16"), - ("fp8", "fp8", "fp16"), ("fp8", "fp8", "fp8"), - ("fp8", "fp8", "nvfp4"), ], ) -@pytest.mark.parametrize("enable_pdl", [True, False, None]) -@pytest.mark.parametrize("enable_sink", [True, False]) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("enable_sink", [False]) def test_trtllm_batch_decode( kv_layout, batch_size, @@ -589,7 +566,7 @@ def test_trtllm_batch_decode( # Set up test parameters torch.manual_seed(0) head_dim = 128 - MAX_IN_KV_LEN = 110 + MAX_IN_KV_LEN = 8192 # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size @@ -805,11 +782,16 @@ def test_trtllm_batch_decode( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() -@pytest.mark.parametrize("batch_size", [4, 128, 256]) -@pytest.mark.parametrize("s_qo", [32, 64, 87]) -@pytest.mark.parametrize("s_kv", [32, 64, 87]) -@pytest.mark.parametrize("num_kv_heads", [16, 32]) -@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) +# @pytest.mark.parametrize("batch_size", [4, 128, 256]) +# @pytest.mark.parametrize("s_qo", [32, 64, 87]) +# @pytest.mark.parametrize("s_kv", [32, 64, 87]) +# @pytest.mark.parametrize("num_kv_heads", [16, 32]) +# @pytest.mark.parametrize("head_grp_size", [1, 5, 8]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("s_qo", [1024]) +@pytest.mark.parametrize("s_kv", [1024]) +@pytest.mark.parametrize("num_kv_heads", [128]) +@pytest.mark.parametrize("head_grp_size", [1]) @pytest.mark.parametrize("causal", [True, False]) def test_trtllm_gen_prefill_deepseek( batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal From 5fd31a5042e72cf6144fbbc367e799d522bf59af Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 6 Oct 2025 18:25:03 +0000 Subject: [PATCH 2/3] Cleaning up comment --- tests/attention/test_trtllm_gen_attention.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 99f2c593e3..94d751a0e0 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -782,11 +782,6 @@ def test_trtllm_batch_decode( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() -# @pytest.mark.parametrize("batch_size", [4, 128, 256]) -# @pytest.mark.parametrize("s_qo", [32, 64, 87]) -# @pytest.mark.parametrize("s_kv", [32, 64, 87]) -# @pytest.mark.parametrize("num_kv_heads", [16, 32]) -# @pytest.mark.parametrize("head_grp_size", [1, 5, 8]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("s_qo", [1024]) @pytest.mark.parametrize("s_kv", [1024]) From abbd125ad579488c45215a5172c2d4469761414c Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 9 Oct 2025 18:06:41 +0000 Subject: [PATCH 3/3] Add bs1 as separate test case --- tests/attention/test_trtllm_gen_attention.py | 111 +++++++++++++++++-- 1 file changed, 100 insertions(+), 11 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 94d751a0e0..80853c7dbf 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -529,18 +529,42 @@ def test_trtllm_batch_prefill( @pytest.mark.parametrize( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", [ - (1, 1, 16, 8, 8), + (4, 1, 16, 2, 1), + (4, 1, 32, 2, 5), + (4, 2, 64, 2, 5), + (4, 3, 32, 2, 5), + (4, 3, 64, 2, 1), + (4, 4, 64, 4, 1), + (4, 5, 64, 4, 8), + (128, 1, 64, 2, 5), + (128, 2, 32, 4, 1), + (128, 3, 16, 4, 8), + (128, 4, 16, 2, 5), + (128, 5, 16, 2, 5), + (256, 1, 64, 4, 8), + (256, 2, 16, 2, 8), + (256, 3, 64, 4, 5), + (256, 4, 32, 2, 8), + (256, 5, 32, 2, 1), ], ) -@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize("window_left", [-1, 127]) @pytest.mark.parametrize( "q_dtype,kv_dtype,o_dtype", [ + ("bf16", "bf16", "bf16"), + ("fp16", "fp16", "fp16"), + ("bf16", "fp8", "bf16"), + ("fp16", "fp8", "fp16"), + ("fp8", "fp8", "bf16"), + ("fp8", "fp8", "fp16"), ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), ], ) -@pytest.mark.parametrize("enable_pdl", [None]) -@pytest.mark.parametrize("enable_sink", [False]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("enable_sink", [True, False]) +@pytest.mark.parametrize("max_in_kv_len", [110]) def test_trtllm_batch_decode( kv_layout, batch_size, @@ -554,6 +578,7 @@ def test_trtllm_batch_decode( kv_dtype, enable_pdl, enable_sink, + max_in_kv_len, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: @@ -566,12 +591,11 @@ def test_trtllm_batch_decode( # Set up test parameters torch.manual_seed(0) head_dim = 128 - MAX_IN_KV_LEN = 8192 # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode( - batch_size, q_len_per_req, MAX_IN_KV_LEN + batch_size, q_len_per_req, max_in_kv_len ) # Create query tensor and related data @@ -782,11 +806,61 @@ def test_trtllm_batch_decode( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("s_qo", [1024]) -@pytest.mark.parametrize("s_kv", [1024]) -@pytest.mark.parametrize("num_kv_heads", [128]) -@pytest.mark.parametrize("head_grp_size", [1]) +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (1, 1, 16, 8, 8), + ], +) +@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("fp8", "fp8", "fp8"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("enable_sink", [False]) +@pytest.mark.parametrize("max_in_kv_len", [8192]) +def test_trtllm_batch_decode_bs1( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, +): + pytest.xfail("trtllm-gen decode gets incorrect output with bs1") + test_trtllm_batch_decode( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + ) + + +@pytest.mark.parametrize("batch_size", [4, 128, 256]) +@pytest.mark.parametrize("s_qo", [32, 64, 87]) +@pytest.mark.parametrize("s_kv", [32, 64, 87]) +@pytest.mark.parametrize("num_kv_heads", [16, 32]) +@pytest.mark.parametrize("head_grp_size", [1, 5, 8]) @pytest.mark.parametrize("causal", [True, False]) def test_trtllm_gen_prefill_deepseek( batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal @@ -915,6 +989,21 @@ def test_trtllm_gen_prefill_deepseek( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("s_qo", [1024]) +@pytest.mark.parametrize("s_kv", [1024]) +@pytest.mark.parametrize("num_kv_heads", [128]) +@pytest.mark.parametrize("head_grp_size", [1]) +@pytest.mark.parametrize("causal", [True, False]) +def test_trtllm_gen_prefill_deepseek_bs1( + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal +): + pytest.xfail("trtllm-gen prefill triggers an IMA with bs1") + test_trtllm_gen_prefill_deepseek( + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal + ) + + if __name__ == "__main__": test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False) test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True)