@@ -17,7 +17,8 @@

 GPU_DEVICE = "cuda:0"

-global_workspace_buffer = None
+global_workspace_buffer = None  # can be empty-initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero-initialized
 workspace_size = 128 * 1024 * 1024


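Review note: the two comments above encode the invariant this change relies on. The reference wrappers treat their workspace as pure scratch, so it can come from `torch.empty`, while the trtllm-gen FMHA path keeps its counter region at the start of the workspace and requires it to be zeroed before first use. A minimal sketch of the allocation pattern under those assumptions (the `_alloc_workspace` helper is hypothetical, not part of this diff):

```python
import torch

workspace_size = 128 * 1024 * 1024  # 128 MiB, as in the diff


def _alloc_workspace(zero_init: bool, device: str = "cuda:0") -> torch.Tensor:
    # Hypothetical helper: torch.empty skips the device memset, which is fine
    # for pure-scratch workspaces; torch.zeros pays one memset up front, which
    # the counter-bearing trtllm-gen workspace needs before its first use.
    alloc = torch.zeros if zero_init else torch.empty
    return alloc(workspace_size, dtype=torch.int8, device=device)
```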
@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill(
         else None
     )

-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    workspace_buffer_ref = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer

     # Run reference wrapper
     wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer_ref, kv_layout
     )
     plan_params = {
         "qo_indptr": q_indptr,
@@ -372,6 +378,9 @@ def test_trtllm_batch_prefill(
         o_sf_vec_size=o_sf_vec_size,
         enable_pdl=enable_pdl,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()

     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -414,6 +423,9 @@ def test_trtllm_batch_prefill(
     torch.testing.assert_close(
         output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
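Review note: the same three-line check is repeated after every trtllm-gen call in this file. Per the note(Yingyi) comments, the first 8192 * 256 * 4 bytes (8 MiB of the 128 MiB workspace) are the counter workspace, and the kernel is expected to hand them back zeroed so the globally shared buffer can be reused without another memset. A sketch of a helper that could consolidate the repetition (the name and constant are suggestions, not part of the diff):

```python
import torch

# Extent of the counter region, taken from the note(Yingyi) comments;
# the diff itself warns this size might change in the future.
COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4  # = 8 MiB


def assert_counter_workspace_cleared(workspace_buffer: torch.Tensor) -> None:
    # A non-zero byte in this prefix would mean the kernel consumed its
    # counters without resetting them for the next call.
    assert (workspace_buffer[:COUNTER_WORKSPACE_BYTES].cpu().numpy() == 0).all()
```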
@@ -505,16 +517,21 @@ def test_trtllm_batch_decode(
         else None
     )

-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer

     # Run reference wrapper
     wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True
+        workspace_buffer_ref, kv_layout, use_tensor_cores=True
     )
     plan_params = {
         "indptr": kv_indptr,
@@ -535,7 +552,7 @@ def test_trtllm_batch_decode(
     if q_len_per_req > 1:
         # hide the output_ref from the decode wrapper for the speculative decoding test
         wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffer, kv_layout
+            workspace_buffer_ref, kv_layout
         )
         plan_params_prefill = {
             "qo_indptr": q_indptr,
@@ -576,6 +593,9 @@ def test_trtllm_batch_decode(
         enable_pdl=enable_pdl,
         q_len_per_req=q_len_per_req,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()

     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -648,6 +668,9 @@ def test_trtllm_batch_decode(
         atol=1e-1,
         max_mismatched_elements=5,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


 @pytest.mark.parametrize("batch_size", [4, 128, 256])
@@ -699,7 +722,17 @@ def test_trtllm_gen_prefill_deepseek(
     # Initialize scale
     scale = float(1.0 / (head_dim_qk**0.5))

-    workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device)
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer

     qo_indptr = torch.cat(
         [
@@ -722,7 +755,7 @@ def test_trtllm_gen_prefill_deepseek(
     ).int()

     wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-        torch.zeros(workspace_size, device="cuda", dtype=torch.uint8),
+        workspace_buffer_ref,
         kv_layout="NHD",
         backend="cutlass",
     )
@@ -775,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek(
         atol=1e-3,
         rtol=1e-3,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


 if __name__ == "__main__":
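Review note: all three tests now share the same two-buffer pattern: `workspace_buffer_ref` (allowed to be uninitialized) feeds the reference wrappers, while the zero-initialized `workspace_buffer` feeds the trtllm-gen kernels and is re-checked after each call. A standalone sanity check of the size arithmetic used by those asserts (pure Python, no GPU needed):

```python
# The zero-checked prefix versus the total workspace, per the diff:
counter_bytes = 8192 * 256 * 4  # 8,388,608 bytes
workspace_size = 128 * 1024 * 1024  # 134,217,728 bytes
assert counter_bytes == 8 * 1024 * 1024  # exactly 8 MiB
assert workspace_size == 16 * counter_bytes  # counters cover 1/16 of the buffer
```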