
Commit 10a4dd6

benchislett and lhsjohn authored and committed
[Spec Decode] Enable FlashInfer Spec Decoding (vllm-project#25196)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 0d5c159 commit 10a4dd6

File tree

12 files changed: +250 −49 lines

tests/v1/attention/test_attention_splitting.py

Lines changed: 108 additions & 1 deletion
@@ -9,7 +9,8 @@
 from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
-                                              split_attn_metadata)
+                                              split_attn_metadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.worker.ubatch_utils import create_ubatch_slices


@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


+def apply_split_decodes_and_prefills(query_lens: list[int],
+                                     decode_threshold: int,
+                                     require_uniform: bool):
+    """Helper function to apply split_decodes_and_prefills and return
+    the results."""
+    device = torch.device("cpu")
+    seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
+    common_metadata = create_common_attn_metadata(BatchSpec(
+        seq_lens=seq_lens, query_lens=query_lens),
+                                                  block_size=16,
+                                                  device=device)
+    return split_decodes_and_prefills(common_metadata,
+                                      decode_threshold=decode_threshold,
+                                      require_uniform=require_uniform)
+
+
+def test_split_decodes_and_prefills_nonuniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, False))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
+    query_lens = [1, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 7
+    assert num_prefills == 0
+    assert num_decode_tokens == sum(query_lens)
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_nonuniform_mixed_batch():
+    query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, False))
+    assert num_decodes == 4  # 2, 1, 3, 4 are all <= 4
+    assert num_prefills == 4  # 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 10  # 2 + 1 + 3 + 4
+    assert num_prefill_tokens == 26  # 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, True))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_uniform_all_short_decodes():
+    query_lens = [2, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 2
+    assert num_prefills == 5
+    assert num_decode_tokens == 4
+    assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
+
+
+def test_split_decodes_and_prefills_uniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
+    query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 3  # 2, 2, 2 are all <= 4 and uniform
+    assert num_prefills == 5  # 4, 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 6  # 2 + 2 + 2
+    assert num_prefill_tokens == 30  # 4 + 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
+    query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 1  # only the first 2 is taken as decode
+    assert num_prefills == 7  # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
+    assert num_decode_tokens == 2  # only the first 2
+    assert num_prefill_tokens == (sum(query_lens) - 2)  # rest of the tokens
+
+
 @pytest.mark.parametrize(
     "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
     [
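The added tests pin down the splitting rule through its observable outputs. For orientation, here is a minimal reference sketch of that rule; it is a hypothetical stand-in, not vLLM's split_decodes_and_prefills (which operates on CommonAttentionMetadata rather than a plain list), but it reproduces every assertion above:

# Hypothetical reference of the rule the tests above encode; not the vLLM
# implementation, which takes CommonAttentionMetadata instead of a list.
def split_decodes_and_prefills_reference(query_lens: list[int],
                                         decode_threshold: int,
                                         require_uniform: bool = False):
    # Scan requests in batch order. A request counts as a decode while its
    # query length is <= decode_threshold and, with require_uniform=True,
    # equal to the first request's query length. The first request that
    # fails either check ends the decode run; it and everything after it
    # are treated as prefills.
    num_decodes = 0
    first_len = query_lens[0] if query_lens else 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        if require_uniform and q_len != first_len:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Matches the uniform mixed-batch case with non-uniform decodes above.
assert split_decodes_and_prefills_reference(
    [2, 1, 2, 4, 5, 6, 7, 8], decode_threshold=4,
    require_uniform=True) == (1, 7, 2, 33)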

vllm/utils/flashinfer.py

Lines changed: 15 additions & 0 deletions
@@ -181,14 +181,22 @@ def force_use_trtllm_attention() -> Optional[bool]:
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


+def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
+    """Check if the current configuration supports TRTLLM attention."""
+    has_trtllm = supports_trtllm_attention()
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+
+
 def use_trtllm_attention(
     num_qo_heads: int,
     num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
+    is_prefill: bool,
     has_sinks: bool = False,
+    has_spec: bool = False,
 ) -> bool:
     """Return ``True`` if TRTLLM attention is used."""
     force_use_trtllm = force_use_trtllm_attention()
@@ -214,6 +222,12 @@
         )
         return False

+    if has_spec and not is_prefill:
+        # Speculative decoding requires TRTLLM attention for decodes
+        logger.info_once(
+            "Using TRTLLM attention (enabled for speculative decoding).")
+        return True
+
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
         if has_sinks:
@@ -391,6 +405,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
+    "can_use_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",

vllm/v1/attention/backends/flashinfer.py

Lines changed: 44 additions & 11 deletions
@@ -25,7 +25,8 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+from vllm.utils.flashinfer import (can_use_trtllm_attention,
+                                   flashinfer_disable_q_quantization,
                                    supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -223,6 +224,7 @@ class FlashInferMetadata:

     # For flashinfer trtllm batch decode
     max_q_len: int
+    max_q_len_prefill: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
@@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -302,6 +304,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.q_data_type = self.model_config.dtype

+        supports_spec_as_decode = \
+            can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
@@ -416,7 +422,8 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata,
-                decode_threshold=self.reorder_batch_threshold)
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=True)

         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
@@ -491,20 +498,25 @@ def build(self,
             paged_kv_last_page_len_np,
         )

+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  has_sinks=self.has_sinks)
+                                                  is_prefill=True,
+                                                  has_sinks=self.has_sinks,
+                                                  has_spec=uses_spec_reorder)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 has_sinks=self.has_sinks)
+                                                 is_prefill=False,
+                                                 has_sinks=self.has_sinks,
+                                                 has_spec=uses_spec_reorder)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
@@ -521,6 +533,7 @@ def build(self,
             q_data_type=self.q_data_type,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
+            max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
@@ -577,6 +590,15 @@ def build(self,
             qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+
+            # Recompute max_q_len for the slice of requests we are using
+            # for prefills. This can be different from max_q_len when
+            # we have a non-uniform batch with some short decodes offloaded
+            # to the prefill pathway
+            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
+            attn_metadata.max_q_len_prefill = \
+                int(query_lens_prefill.max().item())
+
             if not attn_metadata.prefill_use_trtllm:
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu,
@@ -607,7 +629,7 @@ def build(self,
                 num_decodes <= self._decode_cudagraph_max_bs)
             if use_cudagraph:
                 num_input_tokens = (
                     self.vllm_config.pad_for_cudagraph(num_decodes))
+                    self.vllm_config.pad_for_cudagraph(num_decode_tokens))
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
@@ -621,7 +643,7 @@ def build(self,
                     num_decodes:num_input_tokens].fill_(1)

             else:
-                num_input_tokens = num_decodes
+                num_input_tokens = num_decode_tokens

             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
@@ -842,6 +864,9 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output

+        # When using spec decoding, num_decodes can be < num_decode_tokens
+        # because some decode requests may have more than one query token.
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens

@@ -874,8 +899,8 @@ def forward(
                 prefill_query = prefill_query.contiguous()
                 workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
-                    num_decode_tokens:]
-                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                    num_decodes:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 assert get_kv_cache_layout() == "HND"
@@ -919,7 +944,7 @@ def forward(
                     workspace_buffer=workspace_buffer,
                     block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
-                    max_q_len=attn_metadata.max_q_len,
+                    max_q_len=attn_metadata.max_q_len_prefill,
                     max_kv_len=attn_metadata.max_seq_len,
                     bmm1_scale=self.bmm1_scale,
                     bmm2_scale=self.bmm2_scale,
@@ -976,6 +1001,14 @@ def forward(
                     assert self.o_sf_scale is None
                     out = output[:num_decode_tokens]

+                if num_decode_tokens % attn_metadata.num_decodes != 0:
+                    # This gets triggered when the dummy_run forces
+                    # attention to be initialized with q_len = 0
+                    q_len_per_req = 1
+                else:
+                    q_len_per_req = \
+                        num_decode_tokens // attn_metadata.num_decodes
+
                 trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache_permute,
@@ -989,7 +1022,7 @@ def forward(
                     sinks=self.sinks,
                     o_sf_scale=self.o_sf_scale,
                     out=out,
-                )
+                    q_len_per_req=q_len_per_req)
                 return output_padded

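The decode-path changes above hinge on the distinction between decode requests (num_decodes) and decode query tokens (num_decode_tokens), which only coincide without speculative decoding. A small worked sketch with illustrative numbers (not taken from this commit):

# Illustrative only: a spec-decode step verifying 2 draft tokens gives each
# decode request 1 + 2 = 3 query tokens.
num_decodes = 4          # decode requests after batch reordering
num_decode_tokens = 12   # 4 requests * 3 query tokens each

# Mirrors the q_len_per_req computation in forward(); a dummy run can leave
# the counts inconsistent, in which case it falls back to 1.
if num_decode_tokens % num_decodes != 0:
    q_len_per_req = 1
else:
    q_len_per_req = num_decode_tokens // num_decodes
assert q_len_per_req == 3

# Request-indexed tensors (seq_lens, block_table_tensor) are now sliced at
# num_decodes, while token-indexed buffers still use num_decode_tokens.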

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Backend for GatedDeltaNet attention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch

@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(

     cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -76,7 +76,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self.reorder_batch_threshold = self.num_spec + 1  # type: ignore[misc]
+        self._init_reorder_batch_threshold(1, self.use_spec_decode)

         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()

vllm/v1/attention/backends/linear_attn.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar

 import torch

@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
 class LinearAttentionMetadataBuilder(
         AttentionMetadataBuilder[LinearAttentionMetadata]):

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@


 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
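Several builders also change reorder_batch_threshold from a ClassVar to a plain instance-level annotation. A generic Python sketch (hypothetical class names, not vLLM code) of why that matters for the per-instance override the spec-decode path now performs:

from typing import ClassVar


class BuilderWithClassVar:
    # Type checkers treat this as a class-level constant; assigning to it
    # through an instance needed "# type: ignore[misc]" (see gdn_attn.py).
    reorder_batch_threshold: ClassVar[int] = 1


class BuilderWithInstanceAttr:
    reorder_batch_threshold: int = 1

    def enable_spec_decode(self, num_spec: int) -> None:
        # Per-instance override is now type-clean.
        self.reorder_batch_threshold = num_spec + 1


b = BuilderWithInstanceAttr()
b.enable_spec_decode(2)
assert b.reorder_batch_threshold == 3
assert BuilderWithInstanceAttr.reorder_batch_threshold == 1  # default intact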

0 commit comments
