60 changes: 48 additions & 12 deletions tests/test_trtllm_gen_attention.py
@@ -17,7 +17,8 @@

GPU_DEVICE = "cuda:0"

global_workspace_buffer = None
global_workspace_buffer = None # can be empty-initialized
global_trtllm_gen_fmha_workspace_buffer = None # must be zero-initialized
workspace_size = 128 * 1024 * 1024


@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill(
else None
)

global global_workspace_buffer
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.zeros(
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=GPU_DEVICE
)
workspace_buffer = global_workspace_buffer
if global_trtllm_gen_fmha_workspace_buffer is None:
Collaborator:

This only zeroes the buffer the first time; later executions won't enter this branch. Is that expected?

And is it possible to reuse the reference workspace buffer?

Collaborator (author):

Yes, it's expected.

No. The trtllm-gen kernel and the fi kernel should each keep their own workspace, since the fi kernel does not require a zero-initialized workspace.

global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=GPU_DEVICE
)
workspace_buffer_ref = global_workspace_buffer
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
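A minimal sketch of the allocation pattern discussed in the thread above (illustrative only; the variable names are assumptions, not code from this diff):

import torch

WORKSPACE_SIZE = 128 * 1024 * 1024  # same size as workspace_size in these tests

# trtllm-gen relies on its counter region starting at zero, so its workspace
# must be zero-initialized.
trtllm_gen_workspace = torch.zeros(WORKSPACE_SIZE, dtype=torch.int8, device="cuda:0")

# The reference (fi) kernel has no zero-init requirement, so an uninitialized
# buffer allocated with torch.empty is sufficient.
reference_workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.int8, device="cuda:0")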
Comment on lines +324 to +334
Contributor:

medium

This block of code for initializing workspace buffers is repeated in test_trtllm_batch_decode (lines 514-524) and test_trtllm_gen_prefill_deepseek (lines 713-723) within this file. A similar block also exists in tests/test_trtllm_gen_mla.py.

To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function in a test utility file (e.g., tests/conftest.py or a new tests/utils.py) and importing it where needed.

A shared helper could look like this:

from typing import Tuple

def get_workspace_buffers(device: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Initializes and returns the workspace buffers."""
    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
    if global_workspace_buffer is None:
        global_workspace_buffer = torch.empty(
            workspace_size, dtype=torch.int8, device=device
        )
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )
    return global_trtllm_gen_fmha_workspace_buffer, global_workspace_buffer

Then you can replace the repeated blocks in all test functions with a single call to this helper.

Suggested change
- global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
- if global_workspace_buffer is None:
-     global_workspace_buffer = torch.empty(
-         workspace_size, dtype=torch.int8, device=GPU_DEVICE
-     )
- if global_trtllm_gen_fmha_workspace_buffer is None:
-     global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
-         workspace_size, dtype=torch.int8, device=GPU_DEVICE
-     )
- workspace_buffer_ref = global_workspace_buffer
- workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+ workspace_buffer, workspace_buffer_ref = get_workspace_buffers(GPU_DEVICE)
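A related sketch, not part of the review: the counter-workspace assertion repeated throughout this diff could also live in such a shared utility. The 8192 * 256 * 4-byte region size is taken from the note in the diff and may change in the future.

import torch

COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4  # counter region at the start of the trtllm-gen workspace

def assert_counter_workspace_zeroed(workspace_buffer: torch.Tensor) -> None:
    """Check that the counter region of the trtllm-gen workspace is still all zeros."""
    counter_region = workspace_buffer[:COUNTER_WORKSPACE_BYTES]
    assert bool((counter_region == 0).all()), "counter workspace must remain zero-initialized"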


# Run reference wrapper
wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
workspace_buffer_ref, kv_layout
)
plan_params = {
"qo_indptr": q_indptr,
@@ -372,6 +378,9 @@ def test_trtllm_batch_prefill(
o_sf_vec_size=o_sf_vec_size,
enable_pdl=enable_pdl,
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()

if o_dtype == "nvfp4":
output, output_ref = unpack_compare_nvfp4(
@@ -414,6 +423,9 @@ def test_trtllm_batch_prefill(
torch.testing.assert_close(
output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@@ -505,16 +517,21 @@ def test_trtllm_batch_decode(
else None
)

global global_workspace_buffer
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.zeros(
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=GPU_DEVICE
)
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=GPU_DEVICE
)
workspace_buffer = global_workspace_buffer
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
workspace_buffer_ref = global_workspace_buffer

# Run reference wrapper
wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout, use_tensor_cores=True
workspace_buffer_ref, kv_layout, use_tensor_cores=True
)
plan_params = {
"indptr": kv_indptr,
@@ -535,7 +552,7 @@
if q_len_per_req > 1:
# hide the output_ref from decode wrapper for speculative decoding test
wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
workspace_buffer_ref, kv_layout
)
plan_params_prefill = {
"qo_indptr": q_indptr,
@@ -576,6 +593,9 @@ def test_trtllm_batch_decode(
enable_pdl=enable_pdl,
q_len_per_req=q_len_per_req,
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()

if o_dtype == "nvfp4":
output, output_ref = unpack_compare_nvfp4(
@@ -648,6 +668,9 @@ def test_trtllm_batch_decode(
atol=1e-1,
max_mismatched_elements=5,
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


@pytest.mark.parametrize("batch_size", [4, 128, 256])
@@ -699,7 +722,17 @@ def test_trtllm_gen_prefill_deepseek(
# Initialize scale
scale = float(1.0 / (head_dim_qk**0.5))

workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device)
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=device
)
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=device
)
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
workspace_buffer_ref = global_workspace_buffer

qo_indptr = torch.cat(
[
@@ -722,7 +755,7 @@
).int()

wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
torch.zeros(workspace_size, device="cuda", dtype=torch.uint8),
workspace_buffer_ref,
kv_layout="NHD",
backend="cutlass",
)
@@ -775,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek(
atol=1e-3,
rtol=1e-3,
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


if __name__ == "__main__":
18 changes: 13 additions & 5 deletions tests/test_trtllm_gen_mla.py
@@ -5,7 +5,8 @@

import flashinfer

global_workspace_buffer = None
global_workspace_buffer = None # can be empty-initialized
global_trtllm_gen_fmha_workspace_buffer = None # must be zero-initialized
workspace_size = 128 * 1024 * 1024


@@ -96,12 +97,17 @@ def test_trtllm_batch_decode_mla(

# Allocate workspace buffer
# todo(Yingyi): calculate the actual size of workspace buffer
global global_workspace_buffer
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.zeros(
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=device
)
workspace_buffer = global_workspace_buffer
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=device
)
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
workspace_buffer_ref = global_workspace_buffer
Comment on lines +100 to +110
Contributor:


medium

This workspace buffer initialization logic is duplicated from tests/test_trtllm_gen_attention.py. To avoid code duplication across test files and improve maintainability, consider creating a shared helper function in a test utility file (e.g., tests/conftest.py or a new tests/utils.py) and importing it in both files.

A shared helper could look like this:

from typing import Tuple

def get_workspace_buffers(device: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Initializes and returns the workspace buffers."""
    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
    if global_workspace_buffer is None:
        global_workspace_buffer = torch.empty(
            workspace_size, dtype=torch.int8, device=device
        )
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )
    return global_trtllm_gen_fmha_workspace_buffer, global_workspace_buffer

After creating the shared helper, you can replace this block with a single call.

Suggested change
- global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
- if global_workspace_buffer is None:
-     global_workspace_buffer = torch.empty(
-         workspace_size, dtype=torch.int8, device=device
-     )
- if global_trtllm_gen_fmha_workspace_buffer is None:
-     global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
-         workspace_size, dtype=torch.int8, device=device
-     )
- workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
- workspace_buffer_ref = global_workspace_buffer
+ workspace_buffer, workspace_buffer_ref = get_workspace_buffers(device)
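A brief usage sketch (the import style and the tests/utils.py location are assumptions following the suggestion above, not existing code):

# at the top of tests/test_trtllm_gen_attention.py and tests/test_trtllm_gen_mla.py
from utils import get_workspace_buffers  # hypothetical shared helper module, tests/utils.py

# inside each test function
workspace_buffer, workspace_buffer_ref = get_workspace_buffers(device)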


bmm1_log2_scale_tensor = (
torch.tensor(
@@ -135,12 +141,14 @@ def test_trtllm_batch_decode_mla(
bmm2_scale_tensor=bmm2_scale_tensor,
enable_pdl=enable_pdl,
)
# check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()

# Run reference attention and align output
sm_scale = scale / (
(128 + 64) ** 0.5
) # use head dimension before matrix absorption
workspace_buffer_ref = torch.empty(workspace_size, dtype=torch.int8, device=device)
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
workspace_buffer_ref,
backend="fa2",