 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+from vllm.utils.flashinfer import (can_use_trtllm_attention,
+                                   flashinfer_disable_q_quantization,
                                    supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -225,6 +226,7 @@ class FlashInferMetadata:
 
     # For flashinfer trtllm batch decode
     max_q_len: int
+    max_q_len_prefill: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
@@ -252,7 +254,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -311,6 +313,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.q_data_type = self.model_config.dtype
 
+        supports_spec_as_decode = \
+            can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
@@ -425,7 +431,8 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata,
-                                       decode_threshold=self.reorder_batch_threshold)
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
 
         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
@@ -503,20 +510,25 @@ def build(self,
             paged_kv_last_page_len_np,
         )
 
+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  has_sinks=self.has_sinks)
+                                                  is_prefill=True,
+                                                  has_sinks=self.has_sinks,
+                                                  has_spec=uses_spec_reorder)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 has_sinks=self.has_sinks)
+                                                 is_prefill=False,
+                                                 has_sinks=self.has_sinks,
+                                                 has_spec=uses_spec_reorder)
         if self.dcp_world_size > 1 and (prefill_use_trtllm
                                         or decode_use_trtllm):
             raise NotImplementedError(
@@ -538,6 +550,7 @@ def build(self,
             q_data_type=self.q_data_type,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
+            max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
@@ -595,6 +608,15 @@ def build(self,
                     prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
 
+            # Recompute max_q_len for the slice of requests we are using
+            # for prefills. This can be different from max_q_len when
+            # we have a non-uniform batch with some short decodes offloaded
+            # to the prefill pathway
+            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
+            attn_metadata.max_q_len_prefill = \
+                int(query_lens_prefill.max().item())
+
+
             if self.dcp_world_size > 1:
                 # init custom mask for interleave kv cache
                 mask_arr = []
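The recomputation added in the hunk above is plain prefix-sum arithmetic: after the decode requests have been sliced off, qo_indptr_cpu holds cumulative query-token offsets for the prefill requests, so adjacent differences give per-request query lengths and their maximum becomes max_q_len_prefill. A minimal standalone sketch of that arithmetic, using made-up lengths rather than values from this change:

import torch

# Three prefill requests with query lengths 5, 1, and 3 (illustrative values).
qo_indptr_cpu = torch.tensor([0, 5, 6, 9], dtype=torch.int32)

# Adjacent differences of the prefix sum recover per-request query lengths.
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]  # tensor([5, 1, 3])

# The longest prefill query is what max_q_len_prefill records.
max_q_len_prefill = int(query_lens_prefill.max().item())  # 5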
@@ -660,7 +682,7 @@ def build(self,
                             num_decodes <= self._decode_cudagraph_max_bs)
             if use_cudagraph:
                 num_input_tokens = (
-                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                    self.vllm_config.pad_for_cudagraph(num_decode_tokens))
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
@@ -674,7 +696,7 @@ def build(self,
                     num_decodes:num_input_tokens].fill_(1)
 
             else:
-                num_input_tokens = num_decodes
+                num_input_tokens = num_decode_tokens
 
             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
@@ -897,6 +919,9 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output
 
+        # When using spec decoding, num_decodes can be < num_decode_tokens
+        # because some decode requests may have more than one query token.
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
@@ -948,8 +973,8 @@ def forward(
                 prefill_query = prefill_query.contiguous()
                 workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
-                    num_decode_tokens:]
-                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                    num_decodes:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 assert get_kv_cache_layout() == "HND"
@@ -993,7 +1018,7 @@ def forward(
                     workspace_buffer=workspace_buffer,
                     block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
-                    max_q_len=attn_metadata.max_q_len,
+                    max_q_len=attn_metadata.max_q_len_prefill,
                     max_kv_len=attn_metadata.max_seq_len,
                     bmm1_scale=self.bmm1_scale,
                     bmm2_scale=self.bmm2_scale,
@@ -1071,6 +1096,14 @@ def forward(
                     assert self.o_sf_scale is None
                     out = output[:num_decode_tokens]
 
+                if num_decode_tokens % attn_metadata.num_decodes != 0:
+                    # This gets triggered when the dummy_run forces
+                    # attention to be initialized with q_len = 0
+                    q_len_per_req = 1
+                else:
+                    q_len_per_req = \
+                        num_decode_tokens // attn_metadata.num_decodes
+
                 trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache_permute,
@@ -1084,7 +1117,7 @@ def forward(
                     sinks=self.sinks,
                     o_sf_scale=self.o_sf_scale,
                     out=out,
-                )
+                    q_len_per_req=q_len_per_req)
         return output_padded
 
 
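On the decode side, the uniform batch reordering requested with require_uniform=True means every decode request in the batch carries the same number of query tokens (typically one verified token plus any speculative tokens), so num_decode_tokens divided by num_decodes recovers the per-request query length handed to trtllm_batch_decode_with_kv_cache as q_len_per_req; the modulo check only trips for the zero-length dummy run. A small sketch of that logic, written as a standalone helper purely for illustration (the num_decodes == 0 guard is added here only so the sketch never divides by zero, it is not part of the diff):

def q_len_per_request(num_decode_tokens: int, num_decodes: int) -> int:
    # With a uniform decode batch the division is exact; fall back to 1
    # for the degenerate dummy-run case (and for an empty batch, a guard
    # added only for this standalone sketch).
    if num_decodes == 0 or num_decode_tokens % num_decodes != 0:
        return 1
    return num_decode_tokens // num_decodes

# Four decode requests, each verifying two speculative tokens (3 query tokens).
assert q_len_per_request(num_decode_tokens=12, num_decodes=4) == 3
# Plain decoding without speculation: one query token per request.
assert q_len_per_request(num_decode_tokens=4, num_decodes=4) == 1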