Skip to content

Commit cd7a0e7

Browse files
benchislett authored and gjc0824 committed
[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (vllm-project#25520)
Signed-off-by: gaojc <1055866782@qq.com>
1 parent d42c61e commit cd7a0e7

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,16 @@
5050

5151
logger = init_logger(__name__)
5252

53+
# Module-level cache for the trtllm-gen workspace tensor; filled in lazily by
# _get_trtllm_gen_workspace_buffer() on its first call.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the shared trtllm-gen workspace buffer, creating it on demand.

    The first call allocates a zero-filled ``torch.uint8`` tensor with
    ``FLASHINFER_WORKSPACE_BUFFER_SIZE`` elements on the ``'cuda'`` device
    and stores it in the module-level ``trtllm_gen_workspace_buffer``.
    Every later call returns that same tensor object.
    """
    global trtllm_gen_workspace_buffer
    # Fast path: the buffer already exists, so hand back the cached tensor.
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer
    # Slow path (first call only): allocate and remember the buffer.
    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
    return trtllm_gen_workspace_buffer
62+
5363

5464
@triton.jit
5565
def _trtllm_prefill_attn_kvfp8_dequant(
@@ -936,7 +946,7 @@ def forward(
936946
else:
937947
# prefill_query may be non-contiguous
938948
prefill_query = prefill_query.contiguous()
939-
workspace_buffer = prefill_wrapper._float_workspace_buffer
949+
workspace_buffer = _get_trtllm_gen_workspace_buffer()
940950
block_tables_prefill = attn_metadata.block_table_tensor[
941951
num_decode_tokens:]
942952
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@@ -1038,7 +1048,7 @@ def forward(
10381048
else:
10391049
# decode_query may be non-contiguous
10401050
decode_query = decode_query.contiguous()
1041-
workspace_buffer = decode_wrapper._float_workspace_buffer
1051+
workspace_buffer = _get_trtllm_gen_workspace_buffer()
10421052
block_tables_decode = attn_metadata.\
10431053
block_table_tensor[:num_decode_tokens]
10441054
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

0 commit comments

Comments
 (0)