
Commit 1c93a20

benchislett and lhsjohn authored and committed
[Spec Decode] Enable FlashInfer Spec Decoding (vllm-project#25196)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
1 parent 9c00918 commit 1c93a20

File tree

12 files changed: +250 -49 lines

tests/v1/attention/test_attention_splitting.py

Lines changed: 108 additions & 1 deletion
@@ -9,7 +9,8 @@
 from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
-                                              split_attn_metadata)
+                                              split_attn_metadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.worker.ubatch_utils import create_ubatch_slices
 
 
@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
 
 
+def apply_split_decodes_and_prefills(query_lens: list[int],
+                                     decode_threshold: int,
+                                     require_uniform: bool):
+    """Helper function to apply split_decodes_and_prefills and return
+    the results."""
+    device = torch.device("cpu")
+    seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
+    common_metadata = create_common_attn_metadata(BatchSpec(
+        seq_lens=seq_lens, query_lens=query_lens),
+                                                  block_size=16,
+                                                  device=device)
+    return split_decodes_and_prefills(common_metadata,
+                                      decode_threshold=decode_threshold,
+                                      require_uniform=require_uniform)
+
+
+def test_split_decodes_and_prefills_nonuniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, False))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
+    query_lens = [1, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 7
+    assert num_prefills == 0
+    assert num_decode_tokens == sum(query_lens)
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_nonuniform_mixed_batch():
+    query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, False))
+    assert num_decodes == 4  # 2, 1, 3, 4 are all <= 4
+    assert num_prefills == 4  # 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 10  # 2 + 1 + 3 + 4
+    assert num_prefill_tokens == 26  # 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, True))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_uniform_all_short_decodes():
+    query_lens = [2, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 2
+    assert num_prefills == 5
+    assert num_decode_tokens == 4
+    assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
+
+
+def test_split_decodes_and_prefills_uniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
+    query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 3  # 2, 2, 2 are all <= 4 and uniform
+    assert num_prefills == 5  # 4, 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 6  # 2 + 2 + 2
+    assert num_prefill_tokens == 30  # 4 + 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
+    query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 1  # only the first 2 is taken as decode
+    assert num_prefills == 7  # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
+    assert num_decode_tokens == 2  # only the first 2
+    assert num_prefill_tokens == (sum(query_lens) - 2)  # rest of the tokens
+
+
 @pytest.mark.parametrize(
     "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
     [
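
The tests above pin down the splitting rule: requests at the front of the (already reordered) batch count as decodes while their query length stays at or below decode_threshold, and with require_uniform=True they must also match the first request's query length. A standalone sketch of that rule, not the vLLM implementation itself, for readers skimming the diff:

    def split_counts(query_lens: list[int], decode_threshold: int,
                     require_uniform: bool) -> tuple[int, int, int, int]:
        # Walk front-to-back; stop at the first request that no longer
        # qualifies as a decode.
        num_decodes = 0
        first_len = query_lens[0] if query_lens else 0
        for q in query_lens:
            if q > decode_threshold:
                break
            if require_uniform and q != first_len:
                break
            num_decodes += 1
        num_decode_tokens = sum(query_lens[:num_decodes])
        num_prefill_tokens = sum(query_lens[num_decodes:])
        return (num_decodes, len(query_lens) - num_decodes,
                num_decode_tokens, num_prefill_tokens)

    # Mirrors test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes
    assert split_counts([2, 1, 2, 4, 5, 6, 7, 8], 4, True) == (1, 7, 2, 33)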

vllm/utils/flashinfer.py

Lines changed: 15 additions & 0 deletions
@@ -181,14 +181,22 @@ def force_use_trtllm_attention() -> Optional[bool]:
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
+def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
+    """Check if the current configuration supports TRTLLM attention."""
+    has_trtllm = supports_trtllm_attention()
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+
+
 def use_trtllm_attention(
     num_qo_heads: int,
     num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
+    is_prefill: bool,
     has_sinks: bool = False,
+    has_spec: bool = False,
 ) -> bool:
     """Return ``True`` if TRTLLM attention is used."""
     force_use_trtllm = force_use_trtllm_attention()
@@ -214,6 +222,12 @@
         )
         return False
 
+    if has_spec and not is_prefill:
+        # Speculative decoding requires TRTLLM attention for decodes
+        logger.info_once(
+            "Using TRTLLM attention (enabled for speculative decoding).")
+        return True
+
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
         if has_sinks:
@@ -391,6 +405,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
+    "can_use_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
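
The new can_use_trtllm_attention helper gates on two things: the TRTLLM kernels being available and the query heads dividing evenly across the KV heads. A quick illustration of the head-count check only, assuming supports_trtllm_attention() already returned True; the head configurations below are made up:

    for num_qo_heads, num_kv_heads in [(32, 8), (32, 32), (30, 8)]:
        eligible = num_qo_heads % num_kv_heads == 0
        print(num_qo_heads, num_kv_heads, eligible)
    # 32  8  True   (GQA, group size 4)
    # 32 32  True   (MHA)
    # 30  8  False  (query heads do not divide evenly over KV heads)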

vllm/v1/attention/backends/flashinfer.py

Lines changed: 44 additions & 11 deletions
@@ -27,7 +27,8 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+from vllm.utils.flashinfer import (can_use_trtllm_attention,
+                                   flashinfer_disable_q_quantization,
                                    supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -225,6 +226,7 @@ class FlashInferMetadata:
 
     # For flashinfer trtllm batch decode
     max_q_len: int
+    max_q_len_prefill: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
@@ -252,7 +254,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -311,6 +313,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.q_data_type = self.model_config.dtype
 
+        supports_spec_as_decode = \
+            can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
@@ -425,7 +431,8 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata,
-                decode_threshold=self.reorder_batch_threshold)
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=True)
 
         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
@@ -503,20 +510,25 @@
             paged_kv_last_page_len_np,
         )
 
+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  has_sinks=self.has_sinks)
+                                                  is_prefill=True,
+                                                  has_sinks=self.has_sinks,
+                                                  has_spec=uses_spec_reorder)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 has_sinks=self.has_sinks)
+                                                 is_prefill=False,
+                                                 has_sinks=self.has_sinks,
+                                                 has_spec=uses_spec_reorder)
         if self.dcp_world_size > 1 and (prefill_use_trtllm
                                         or decode_use_trtllm):
             raise NotImplementedError(
@@ -538,6 +550,7 @@
             q_data_type=self.q_data_type,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
+            max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
@@ -595,6 +608,15 @@
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
 
+            # Recompute max_q_len for the slice of requests we are using
+            # for prefills. This can be different from max_q_len when
+            # we have a non-uniform batch with some short decodes offloaded
+            # to the prefill pathway
+            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
+            attn_metadata.max_q_len_prefill = \
+                int(query_lens_prefill.max().item())
+
+
             if self.dcp_world_size > 1:
                 # init custom mask for interleave kv cache
                 mask_arr = []
@@ -660,7 +682,7 @@
                 num_decodes <= self._decode_cudagraph_max_bs)
             if use_cudagraph:
                 num_input_tokens = (
-                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                    self.vllm_config.pad_for_cudagraph(num_decode_tokens))
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
@@ -674,7 +696,7 @@
                     num_decodes:num_input_tokens].fill_(1)
 
             else:
-                num_input_tokens = num_decodes
+                num_input_tokens = num_decode_tokens
 
             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
@@ -897,6 +919,9 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output
 
+        # When using spec decoding, num_decodes can be < num_decode_tokens
+        # because some decode requests may have more than one query token.
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
@@ -948,8 +973,8 @@ def forward(
                 prefill_query = prefill_query.contiguous()
                 workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
-                    num_decode_tokens:]
-                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                    num_decodes:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 assert get_kv_cache_layout() == "HND"
@@ -993,7 +1018,7 @@
                     workspace_buffer=workspace_buffer,
                     block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
-                    max_q_len=attn_metadata.max_q_len,
+                    max_q_len=attn_metadata.max_q_len_prefill,
                     max_kv_len=attn_metadata.max_seq_len,
                     bmm1_scale=self.bmm1_scale,
                     bmm2_scale=self.bmm2_scale,
@@ -1071,6 +1096,14 @@
                     assert self.o_sf_scale is None
                     out = output[:num_decode_tokens]
 
+                if num_decode_tokens % attn_metadata.num_decodes != 0:
+                    # This gets triggered when the dummy_run forces
+                    # attention to be initialized with q_len = 0
+                    q_len_per_req = 1
+                else:
+                    q_len_per_req = \
+                        num_decode_tokens // attn_metadata.num_decodes
+
                 trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache_permute,
@@ -1084,7 +1117,7 @@
                     sinks=self.sinks,
                     o_sf_scale=self.o_sf_scale,
                     out=out,
-                )
+                    q_len_per_req=q_len_per_req)
                 return output_padded
 
 
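With spec decoding reordered into the decode pathway, each decode request carries one target token plus its speculative tokens, which is why the forward pass above derives q_len_per_req from the token-to-request ratio. A worked example under assumed numbers (3 speculative tokens per request):

    # Hypothetical batch: 4 decode requests, each with 1 target token
    # plus 3 speculative tokens to verify -> 4 query tokens per request.
    num_decodes = 4
    num_decode_tokens = 16

    if num_decode_tokens % num_decodes != 0:
        # Mirrors the dummy_run corner case handled in the diff above.
        q_len_per_req = 1
    else:
        q_len_per_req = num_decode_tokens // num_decodes

    assert q_len_per_req == 4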
vllm/v1/attention/backends/gdn_attn.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Backend for GatedDeltaNet attention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(
 
     cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -76,7 +76,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self.reorder_batch_threshold = self.num_spec + 1  # type: ignore[misc]
+        self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
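
The line removed above (self.reorder_batch_threshold = self.num_spec + 1) hints at what the shared _init_reorder_batch_threshold helper is expected to reproduce: raise the decode threshold to cover one verified token plus the speculative tokens whenever the backend can treat spec decode as decode. A rough, hypothetical sketch of that behavior; the real helper on the metadata-builder base class may be implemented differently:

    def _init_reorder_batch_threshold(self, reorder_batch_threshold: int,
                                      supports_spec_as_decode: bool) -> None:
        # Sketch only: `num_spec` is assumed to hold the number of
        # speculative tokens configured for this run (0 if disabled).
        self.reorder_batch_threshold = reorder_batch_threshold
        if supports_spec_as_decode and getattr(self, "num_spec", 0) > 0:
            self.reorder_batch_threshold = 1 + self.num_spec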

vllm/v1/attention/backends/linear_attn.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
 class LinearAttentionMetadataBuilder(
         AttentionMetadataBuilder[LinearAttentionMetadata]):
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 