@@ -50,6 +50,16 @@
 
 logger = init_logger(__name__)
 
+trtllm_gen_workspace_buffer = None
+
+
+def _get_trtllm_gen_workspace_buffer():
+    global trtllm_gen_workspace_buffer
+    if trtllm_gen_workspace_buffer is None:
+        trtllm_gen_workspace_buffer = torch.zeros(
+            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
+    return trtllm_gen_workspace_buffer
+
 
 @triton.jit
 def _trtllm_prefill_attn_kvfp8_dequant(
@@ -936,7 +946,7 @@ def forward(
 else:
     # prefill_query may be non-contiguous
     prefill_query = prefill_query.contiguous()
-    workspace_buffer = prefill_wrapper._float_workspace_buffer
+    workspace_buffer = _get_trtllm_gen_workspace_buffer()
     block_tables_prefill = attn_metadata.block_table_tensor[
         num_decode_tokens:]
     seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@@ -1038,7 +1048,7 @@ def forward(
 else:
     # decode_query may be non-contiguous
     decode_query = decode_query.contiguous()
-    workspace_buffer = decode_wrapper._float_workspace_buffer
+    workspace_buffer = _get_trtllm_gen_workspace_buffer()
     block_tables_decode = attn_metadata.\
         block_table_tensor[:num_decode_tokens]
     seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
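For reference, the added helper follows a lazy module-level singleton pattern: the uint8 workspace tensor is allocated on first use and every later call returns the same allocation, so the prefill and decode paths share one buffer instead of each borrowing the private _float_workspace_buffer of its FlashInfer wrapper. Below is a minimal standalone sketch of that pattern only; the size constant, the CPU fallback, and the names get_workspace / _WORKSPACE_BYTES are illustrative placeholders, not the identifiers or values used in this file.

import torch

# Placeholder size; the real module defines FLASHINFER_WORKSPACE_BUFFER_SIZE.
_WORKSPACE_BYTES = 128 * 1024 * 1024

_workspace = None


def get_workspace() -> torch.Tensor:
    """Allocate the workspace once, then hand back the same tensor on every call."""
    global _workspace
    if _workspace is None:
        # CPU fallback only so the sketch runs anywhere; the kernel needs a CUDA buffer.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _workspace = torch.zeros(_WORKSPACE_BYTES, dtype=torch.uint8, device=device)
    return _workspace


# Both call sites (e.g. prefill and decode) observe the exact same allocation.
assert get_workspace() is get_workspace()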