Skip to content

Commit cb825af

Browse files
benchislett authored and
yewentao256 committed
[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 342d17f commit cb825af

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@
4848

4949
logger = init_logger(__name__)
5050

51+
# Module-level cache for the workspace buffer handed to the trtllm-gen
# attention paths. It stays None until the accessor below is first called.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the shared trtllm-gen workspace buffer, creating it on first use.

    The first call allocates a zero-filled ``torch.uint8`` tensor with
    ``FLASHINFER_WORKSPACE_BUFFER_SIZE`` elements on the ``'cuda'`` device and
    stores it in the module-level ``trtllm_gen_workspace_buffer``. Every
    later call returns that same tensor object without allocating again.

    Returns:
        The cached workspace tensor.
    """
    global trtllm_gen_workspace_buffer

    # Fast path: the buffer already exists, so hand it back untouched.
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer

    # Slow path (first call only): allocate once and remember the result.
    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device='cuda',
    )
    return trtllm_gen_workspace_buffer
60+
5161

5262
@triton.jit
5363
def _trtllm_prefill_attn_kvfp8_dequant(
@@ -862,7 +872,7 @@ def forward(
862872
else:
863873
# prefill_query may be non-contiguous
864874
prefill_query = prefill_query.contiguous()
865-
workspace_buffer = prefill_wrapper._float_workspace_buffer
875+
workspace_buffer = _get_trtllm_gen_workspace_buffer()
866876
block_tables_prefill = attn_metadata.block_table_tensor[
867877
num_decode_tokens:]
868878
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
@@ -943,7 +953,7 @@ def forward(
943953
else:
944954
# decode_query may be non-contiguous
945955
decode_query = decode_query.contiguous()
946-
workspace_buffer = decode_wrapper._float_workspace_buffer
956+
workspace_buffer = _get_trtllm_gen_workspace_buffer()
947957
block_tables_decode = attn_metadata.\
948958
block_table_tensor[:num_decode_tokens]
949959
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

0 commit comments

Comments
 (0)