-
Notifications
You must be signed in to change notification settings - Fork 561
fix: zero-init workspace buffer for trtllm-gen fmha #1643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,7 +17,8 @@ | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| GPU_DEVICE = "cuda:0" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| global_workspace_buffer = None | ||||||||||||||||||||||||||||||
| global_workspace_buffer = None # can.be empty initialized | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = None # must be zero initialized | ||||||||||||||||||||||||||||||
| workspace_size = 128 * 1024 * 1024 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill( | |||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| global global_workspace_buffer | ||||||||||||||||||||||||||||||
| global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.empty( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=GPU_DEVICE | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer = global_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_trtllm_gen_fmha_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=GPU_DEVICE | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer_ref = global_workspace_buffer | ||||||||||||||||||||||||||||||
| workspace_buffer = global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
|
Comment on lines
+324
to
+334
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block of code for initializing workspace buffers is repeated in To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function in a test utility file (e.g., A shared helper could look like this: from typing import Tuple
def get_workspace_buffers(device: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initializes and returns the workspace buffers."""
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=device
)
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=device
)
return global_trtllm_gen_fmha_workspace_buffer, global_workspace_bufferThen you can replace the repeated blocks in all test functions with a single call to this helper.
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Run reference wrapper | ||||||||||||||||||||||||||||||
| wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( | ||||||||||||||||||||||||||||||
| workspace_buffer, kv_layout | ||||||||||||||||||||||||||||||
| workspace_buffer_ref, kv_layout | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| plan_params = { | ||||||||||||||||||||||||||||||
| "qo_indptr": q_indptr, | ||||||||||||||||||||||||||||||
|
|
@@ -372,6 +378,9 @@ def test_trtllm_batch_prefill( | |||||||||||||||||||||||||||||
| o_sf_vec_size=o_sf_vec_size, | ||||||||||||||||||||||||||||||
| enable_pdl=enable_pdl, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if o_dtype == "nvfp4": | ||||||||||||||||||||||||||||||
| output, output_ref = unpack_compare_nvfp4( | ||||||||||||||||||||||||||||||
|
|
@@ -414,6 +423,9 @@ def test_trtllm_batch_prefill( | |||||||||||||||||||||||||||||
| torch.testing.assert_close( | ||||||||||||||||||||||||||||||
| output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND | ||||||||||||||||||||||||||||||
|
|
@@ -505,16 +517,21 @@ def test_trtllm_batch_decode( | |||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| global global_workspace_buffer | ||||||||||||||||||||||||||||||
| global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.empty( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=GPU_DEVICE | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| if global_trtllm_gen_fmha_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=GPU_DEVICE | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer = global_workspace_buffer | ||||||||||||||||||||||||||||||
| workspace_buffer = global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| workspace_buffer_ref = global_workspace_buffer | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Run reference wrapper | ||||||||||||||||||||||||||||||
| wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( | ||||||||||||||||||||||||||||||
| workspace_buffer, kv_layout, use_tensor_cores=True | ||||||||||||||||||||||||||||||
| workspace_buffer_ref, kv_layout, use_tensor_cores=True | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| plan_params = { | ||||||||||||||||||||||||||||||
| "indptr": kv_indptr, | ||||||||||||||||||||||||||||||
|
|
@@ -535,7 +552,7 @@ def test_trtllm_batch_decode( | |||||||||||||||||||||||||||||
| if q_len_per_req > 1: | ||||||||||||||||||||||||||||||
| # hide the output_ref from decode wrapper for speculative decoding test | ||||||||||||||||||||||||||||||
| wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( | ||||||||||||||||||||||||||||||
| workspace_buffer, kv_layout | ||||||||||||||||||||||||||||||
| workspace_buffer_ref, kv_layout | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| plan_params_prefill = { | ||||||||||||||||||||||||||||||
| "qo_indptr": q_indptr, | ||||||||||||||||||||||||||||||
|
|
@@ -576,6 +593,9 @@ def test_trtllm_batch_decode( | |||||||||||||||||||||||||||||
| enable_pdl=enable_pdl, | ||||||||||||||||||||||||||||||
| q_len_per_req=q_len_per_req, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if o_dtype == "nvfp4": | ||||||||||||||||||||||||||||||
| output, output_ref = unpack_compare_nvfp4( | ||||||||||||||||||||||||||||||
|
|
@@ -648,6 +668,9 @@ def test_trtllm_batch_decode( | |||||||||||||||||||||||||||||
| atol=1e-1, | ||||||||||||||||||||||||||||||
| max_mismatched_elements=5, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("batch_size", [4, 128, 256]) | ||||||||||||||||||||||||||||||
|
|
@@ -699,7 +722,17 @@ def test_trtllm_gen_prefill_deepseek( | |||||||||||||||||||||||||||||
| # Initialize scale | ||||||||||||||||||||||||||||||
| scale = float(1.0 / (head_dim_qk**0.5)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device) | ||||||||||||||||||||||||||||||
| global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.empty( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=device | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| if global_trtllm_gen_fmha_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=device | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer = global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| workspace_buffer_ref = global_workspace_buffer | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| qo_indptr = torch.cat( | ||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||
|
|
@@ -722,7 +755,7 @@ def test_trtllm_gen_prefill_deepseek( | |||||||||||||||||||||||||||||
| ).int() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( | ||||||||||||||||||||||||||||||
| torch.zeros(workspace_size, device="cuda", dtype=torch.uint8), | ||||||||||||||||||||||||||||||
| workspace_buffer_ref, | ||||||||||||||||||||||||||||||
| kv_layout="NHD", | ||||||||||||||||||||||||||||||
| backend="cutlass", | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
@@ -775,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek( | |||||||||||||||||||||||||||||
| atol=1e-3, | ||||||||||||||||||||||||||||||
| rtol=1e-3, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,8 @@ | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import flashinfer | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| global_workspace_buffer = None | ||||||||||||||||||||||||||||||
| global_workspace_buffer = None # can.be empty initialized | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = None # must be zero initialized | ||||||||||||||||||||||||||||||
| workspace_size = 128 * 1024 * 1024 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -96,12 +97,17 @@ def test_trtllm_batch_decode_mla( | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Allocate workspace buffer | ||||||||||||||||||||||||||||||
| # todo(Yingyi): calculate the actual size of workspace buffer | ||||||||||||||||||||||||||||||
| global global_workspace_buffer | ||||||||||||||||||||||||||||||
| global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| global_workspace_buffer = torch.empty( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=device | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer = global_workspace_buffer | ||||||||||||||||||||||||||||||
| if global_trtllm_gen_fmha_workspace_buffer is None: | ||||||||||||||||||||||||||||||
| global_trtllm_gen_fmha_workspace_buffer = torch.zeros( | ||||||||||||||||||||||||||||||
| workspace_size, dtype=torch.int8, device=device | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| workspace_buffer = global_trtllm_gen_fmha_workspace_buffer | ||||||||||||||||||||||||||||||
| workspace_buffer_ref = global_workspace_buffer | ||||||||||||||||||||||||||||||
|
Comment on lines
+100
to
+110
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This workspace buffer initialization logic is duplicated from A shared helper could look like this: from typing import Tuple
def get_workspace_buffers(device: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initializes and returns the workspace buffers."""
global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
workspace_size, dtype=torch.int8, device=device
)
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size, dtype=torch.int8, device=device
)
return global_trtllm_gen_fmha_workspace_buffer, global_workspace_bufferAfter creating the shared helper, you can replace this block with a single call.
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| bmm1_log2_scale_tensor = ( | ||||||||||||||||||||||||||||||
| torch.tensor( | ||||||||||||||||||||||||||||||
|
|
@@ -135,12 +141,14 @@ def test_trtllm_batch_decode_mla( | |||||||||||||||||||||||||||||
| bmm2_scale_tensor=bmm2_scale_tensor, | ||||||||||||||||||||||||||||||
| enable_pdl=enable_pdl, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero | ||||||||||||||||||||||||||||||
| # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future | ||||||||||||||||||||||||||||||
| assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Run reference attention and align output | ||||||||||||||||||||||||||||||
| sm_scale = scale / ( | ||||||||||||||||||||||||||||||
| (128 + 64) ** 0.5 | ||||||||||||||||||||||||||||||
| ) # use head dimension before matrix absorption | ||||||||||||||||||||||||||||||
| workspace_buffer_ref = torch.empty(workspace_size, dtype=torch.int8, device=device) | ||||||||||||||||||||||||||||||
| wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( | ||||||||||||||||||||||||||||||
| workspace_buffer_ref, | ||||||||||||||||||||||||||||||
| backend="fa2", | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this only make it zero the first time, later execution won't enter this branch, is it expected?
And is it possible to reuse with ref workspace buffer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it's expected.
No. trtllm-gen kernel and fi kernel should re-use individual workspace as fi kernel does not require zero-init workspace.