From 7f492110d5b9827a5df16a529d5bafb6f0a86fe8 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Fri, 5 Sep 2025 00:22:33 -0700 Subject: [PATCH 1/2] Fix FLOPS calculation for bench_trtllm_gen_mla.py --- benchmarks/bench_trtllm_gen_mla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index 3051608322..c59faed65d 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -30,6 +30,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): # Sequence lengths and block tables seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)] seq_lens[-1] = seq_len + avg_seq_len = sum(seq_lens) / len(seq_lens) max_seq_len = max(seq_lens) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) @@ -111,7 +112,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): * batch_size * num_q_heads * (2 * kv_lora_rank + qk_rope_head_dim) - * seq_len + * avg_seq_len * q_len_per_request ) print( From bc421b565effc372b95b109a14838bab5f14c6db Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Fri, 5 Sep 2025 00:54:49 -0700 Subject: [PATCH 2/2] Simplify the fix --- benchmarks/bench_trtllm_gen_mla.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index c59faed65d..b9ac3967dd 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -30,7 +30,6 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): # Sequence lengths and block tables seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)] seq_lens[-1] = seq_len - avg_seq_len = sum(seq_lens) / len(seq_lens) max_seq_len = max(seq_lens) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) @@ -109,10 +108,9 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): ms = np.median(measurements) flops = ( 2 - * batch_size * num_q_heads * (2 * kv_lora_rank + qk_rope_head_dim) - * avg_seq_len + * sum(seq_lens) * q_len_per_request ) print(