 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+from vllm.utils.flashinfer import (can_use_trtllm_attention,
+                                   flashinfer_disable_q_quantization,
                                    supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -223,6 +224,7 @@ class FlashInferMetadata:

     # For flashinfer trtllm batch decode
     max_q_len: int
+    max_q_len_prefill: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
@@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -302,6 +304,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.q_data_type = self.model_config.dtype

+        supports_spec_as_decode = \
+            can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
@@ -416,7 +422,8 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata,
-                                       decode_threshold=self.reorder_batch_threshold)
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)

         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
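
For context on the require_uniform=True change above, here is a minimal toy sketch of what splitting a reordered batch with a uniformity requirement means: every request routed to the decode path must have the same (small) query length, e.g. 1 + num_speculative_tokens. This is an illustrative stand-in under assumed semantics, not the actual split_decodes_and_prefills implementation, and the query lengths are invented.

# Toy sketch only (not the vLLM helper). The batch is assumed to already be
# reordered so decode-like requests come first.
def toy_split_decodes_and_prefills(query_lens, decode_threshold=1):
    num_decodes = 0
    for q_len in query_lens:
        # Stop at the first request that is too long to be a decode, or whose
        # length differs from the first decode (the uniformity requirement).
        if q_len > decode_threshold:
            break
        if num_decodes > 0 and q_len != query_lens[0]:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# e.g. spec decode with 2 draft tokens -> decode requests have q_len == 3
print(toy_split_decodes_and_prefills([3, 3, 3, 17, 42], decode_threshold=3))
# (3, 2, 9, 59)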
@@ -491,20 +498,25 @@ def build(self,
             paged_kv_last_page_len_np,
         )

+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  has_sinks=self.has_sinks)
+                                                  is_prefill=True,
+                                                  has_sinks=self.has_sinks,
+                                                  has_spec=uses_spec_reorder)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 has_sinks=self.has_sinks)
+                                                 is_prefill=False,
+                                                 has_sinks=self.has_sinks,
+                                                 has_spec=uses_spec_reorder)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
@@ -521,6 +533,7 @@ def build(self,
             q_data_type=self.q_data_type,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
+            max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
@@ -577,6 +590,15 @@ def build(self,
             qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+
+            # Recompute max_q_len for the slice of requests we are using
+            # for prefills. This can be different from max_q_len when
+            # we have a non-uniform batch with some short decodes offloaded
+            # to the prefill pathway
+            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
+            attn_metadata.max_q_len_prefill = \
+                int(query_lens_prefill.max().item())
+
             if not attn_metadata.prefill_use_trtllm:
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu,
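
The recomputation of max_q_len_prefill above relies on qo_indptr_cpu being a cumulative-sum ("indptr") tensor, so adjacent differences recover per-request query lengths. A small self-contained sketch with made-up values:

import torch

# Indptr for 3 prefill requests with query lengths 1, 7 and 300 (invented).
qo_indptr_cpu = torch.tensor([0, 1, 1 + 7, 1 + 7 + 300])
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
print(query_lens_prefill)                    # tensor([  1,   7, 300])
print(int(query_lens_prefill.max().item()))  # 300, i.e. max_q_len_prefill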
@@ -607,7 +629,7 @@ def build(self,
                 num_decodes <= self._decode_cudagraph_max_bs)
             if use_cudagraph:
                 num_input_tokens = (
-                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                    self.vllm_config.pad_for_cudagraph(num_decode_tokens))
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
@@ -621,7 +643,7 @@ def build(self,
                     num_decodes:num_input_tokens].fill_(1)

             else:
-                num_input_tokens = num_decodes
+                num_input_tokens = num_decode_tokens

             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
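
A rough sketch of why the padding target changes from num_decodes to num_decode_tokens: with speculative decoding each decode request contributes several query tokens, and the CUDA graph is sized by input tokens rather than requests. pad_to_next_bucket below is a hypothetical stand-in for pad_for_cudagraph, and all numbers are invented.

# Hypothetical bucket padding, standing in for pad_for_cudagraph.
def pad_to_next_bucket(n, buckets=(1, 2, 4, 8, 16, 32, 64)):
    return next(b for b in buckets if b >= n)

num_decodes = 5            # 5 decode requests
q_len_per_req = 3          # 1 real token + 2 speculative tokens each
num_decode_tokens = num_decodes * q_len_per_req   # 15 query tokens

print(pad_to_next_bucket(num_decodes))        # 8  -> sized by requests, too small
print(pad_to_next_bucket(num_decode_tokens))  # 16 -> sized by query tokens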
@@ -842,6 +864,9 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output

+        # When using spec decoding, num_decodes can be < num_decode_tokens
+        # because some decode requests may have more than one query token.
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens

@@ -874,8 +899,8 @@ def forward(
             prefill_query = prefill_query.contiguous()
             workspace_buffer = _get_trtllm_gen_workspace_buffer()
             block_tables_prefill = attn_metadata.block_table_tensor[
-                num_decode_tokens:]
-            seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                num_decodes:]
+            seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

             # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
             assert get_kv_cache_layout() == "HND"
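
The slicing fix above hinges on which tensors are per-request versus per-token: block_table_tensor and seq_lens carry one entry per request, while the query tensor carries one row per token, so their prefill slices start at num_decodes and num_decode_tokens respectively. A small sketch with invented shapes:

import torch

num_decodes, num_prefills = 4, 2
q_len_per_req = 3                                 # spec decode: 3 tokens per decode
num_decode_tokens = num_decodes * q_len_per_req   # 12

seq_lens = torch.arange(num_decodes + num_prefills)   # one entry per request (6)
query = torch.zeros(num_decode_tokens + 50, 8, 64)    # one row per query token

seq_lens_prefill = seq_lens[num_decodes:]         # correct: 2 prefill entries
prefill_query = query[num_decode_tokens:]         # correct: 50 prefill token rows
# Slicing seq_lens with num_decode_tokens (12) would run past the 6 requests
# and yield an empty tensor, which is what this hunk fixes.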
@@ -919,7 +944,7 @@ def forward(
                 workspace_buffer=workspace_buffer,
                 block_tables=mock_block_table,
                 seq_lens=seq_lens_prefill,
-                max_q_len=attn_metadata.max_q_len,
+                max_q_len=attn_metadata.max_q_len_prefill,
                 max_kv_len=attn_metadata.max_seq_len,
                 bmm1_scale=self.bmm1_scale,
                 bmm2_scale=self.bmm2_scale,
@@ -976,6 +1001,14 @@ def forward(
                 assert self.o_sf_scale is None
                 out = output[:num_decode_tokens]

+            if num_decode_tokens % attn_metadata.num_decodes != 0:
+                # This gets triggered when the dummy_run forces
+                # attention to be initialized with q_len = 0
+                q_len_per_req = 1
+            else:
+                q_len_per_req = \
+                    num_decode_tokens // attn_metadata.num_decodes
+
             trtllm_batch_decode_with_kv_cache(
                 query=decode_query,
                 kv_cache=kv_cache_permute,
@@ -989,7 +1022,7 @@ def forward(
                 sinks=self.sinks,
                 o_sf_scale=self.o_sf_scale,
                 out=out,
-            )
+                q_len_per_req=q_len_per_req)
         return output_padded


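Finally, a sketch of the q_len_per_req guard added in the decode path above: under a uniform spec-decode batch the per-request query length can be recovered by division, while the fallback covers the dummy-run case where the counts do not divide evenly. infer_q_len_per_req is a hypothetical helper for illustration, not part of the patch, and it adds a zero-decode guard the patch itself does not need.

def infer_q_len_per_req(num_decode_tokens, num_decodes):
    # Fall back to 1 when the counts do not divide evenly (or there are no
    # decodes at all), mirroring the dummy-run fallback in the patch.
    if num_decodes == 0 or num_decode_tokens % num_decodes != 0:
        return 1
    return num_decode_tokens // num_decodes

print(infer_q_len_per_req(12, 4))  # 3 -> 1 real + 2 speculative tokens per request
print(infer_q_len_per_req(7, 4))   # 1 -> non-uniform / dummy-run fallback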