|
48 | 48 |
|
49 | 49 | logger = init_logger(__name__) |
50 | 50 |
|
# Process-wide workspace shared by the trtllm-gen prefill and decode paths;
# allocated lazily on first use so import never touches the GPU.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the shared trtllm-gen workspace buffer, allocating it on demand.

    The buffer is a zero-initialized uint8 CUDA tensor of
    FLASHINFER_WORKSPACE_BUFFER_SIZE bytes, created once and cached in the
    module-level global for all subsequent calls.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer
    trtllm_gen_workspace_buffer = torch.zeros(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                                              dtype=torch.uint8,
                                              device='cuda')
    return trtllm_gen_workspace_buffer
51 | 61 |
|
52 | 62 | @triton.jit |
53 | 63 | def _trtllm_prefill_attn_kvfp8_dequant( |
@@ -862,7 +872,7 @@ def forward( |
862 | 872 | else: |
863 | 873 | # prefill_query may be non-contiguous |
864 | 874 | prefill_query = prefill_query.contiguous() |
865 | | - workspace_buffer = prefill_wrapper._float_workspace_buffer |
| 875 | + workspace_buffer = _get_trtllm_gen_workspace_buffer() |
866 | 876 | block_tables_prefill = attn_metadata.block_table_tensor[ |
867 | 877 | num_decode_tokens:] |
868 | 878 | seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] |
@@ -943,7 +953,7 @@ def forward( |
943 | 953 | else: |
944 | 954 | # decode_query may be non-contiguous |
945 | 955 | decode_query = decode_query.contiguous() |
946 | | - workspace_buffer = decode_wrapper._float_workspace_buffer |
| 956 | + workspace_buffer = _get_trtllm_gen_workspace_buffer() |
947 | 957 | block_tables_decode = attn_metadata.\ |
948 | 958 | block_table_tensor[:num_decode_tokens] |
949 | 959 | seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] |
|
0 commit comments