2121from vllm .sequence import IntermediateTensors
2222from vllm .tasks import SupportedTask
2323from vllm .utils .math_utils import cdiv
24+ from vllm .v1 .core .sched .output import GrammarOutput
2425from vllm .v1 .core .sched .output import SchedulerOutput as VllmSchedulerOutput
2526from vllm .v1 .kv_cache_interface import KVCacheConfig
2627from vllm .v1 .outputs import (EMPTY_MODEL_RUNNER_OUTPUT , AsyncModelRunnerOutput ,
27- DraftTokenIds , ModelRunnerOutput )
28+ DraftTokenIds , KVConnectorOutput ,
29+ ModelRunnerOutput )
2830from vllm .v1 .request import Request
2931from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
3032from vllm .v1 .worker .kv_connector_model_runner_mixin import \
5153from tpu_inference .runner .multimodal_manager import MultiModalManager
5254from tpu_inference .runner .persistent_batch_manager import \
5355 PersistentBatchManager
54- from tpu_inference .runner .speculative_decoding_manager import \
55- SpeculativeDecodingManager
56+ from tpu_inference .runner .speculative_decoding_manager import (
57+ SpecDecodeMetadata , SpeculativeDecodingManager )
5658from tpu_inference .runner .structured_decoding_manager import \
5759 StructuredDecodingManager
5860from tpu_inference .spec_decode .jax .eagle3 import Eagle3Proposer
@@ -126,6 +128,21 @@ class AsyncPreResults:
126128 placeholder_req_id_to_index : dict [str , int ]
127129
128130
@dataclass
class ExecuteModelState:
    """Ephemeral cached state transferred between execute_model() and
    sample_tokens(), after execute_model() returns None.

    An instance is created at the end of the forward pass and stored on the
    runner; sample_tokens() reads every field and then clears the stored
    reference, so each instance is consumed by exactly one sample_tokens()
    call.
    """

    # Scheduler output of the step whose forward pass produced this state.
    scheduler_output: "VllmSchedulerOutput"
    # Attention metadata built while preparing the inputs for this step.
    attn_metadata: AttentionMetadata
    # Input token ids fed to the model; declared Optional, so consumers must
    # tolerate None (NOTE(review): presumably None when embeddings are used
    # instead of token ids -- confirm against _prepare_inputs).
    input_ids: Optional[jax.Array]
    # Hidden states returned by the forward pass.
    hidden_states: jax.Array
    # Logits computed from the hidden states. sample_tokens() may replace
    # them with a structured-decoding-masked version before sampling.
    logits: jax.Array
    # Auxiliary hidden states from the forward pass, if any
    # (NOTE(review): presumably consumed by the speculative drafter --
    # confirm).
    aux_hidden_states: Optional[jax.Array]
    # Speculative-decoding metadata; None selects the plain sampling path,
    # otherwise the rejection sampler is used.
    spec_decode_metadata: Optional[SpecDecodeMetadata]
    # KV-connector output of this step, forwarded into the final
    # ModelRunnerOutput.
    kv_connector_output: Optional[KVConnectorOutput]
144+
145+
129146@functools .partial (jax .jit , donate_argnums = (0 , 1 , 2 ))
130147def _substitute_placeholder_token (
131148 input_ids : jax .Array , token_in_tpu_cur_input_indices : jax .Array ,
@@ -215,6 +232,7 @@ def __init__(
215232
216233 self ._pre_async_results : AsyncPreResults | None = None
217234 self ._substitute_placeholder_token_fn = _substitute_placeholder_token
235+ self .execute_model_state : ExecuteModelState | None = None
218236
219237 def _init_random (self ):
220238 if self .model_config .seed is None :
def execute_model(
    self,
    scheduler_output: "VllmSchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> ModelRunnerOutput | None:
    """Run the forward pass for one scheduler step.

    Sampling is deferred: when the forward pass actually runs, its results
    are stashed in ``self.execute_model_state`` and ``None`` is returned;
    the caller must then invoke ``sample_tokens()`` to obtain the output.
    When ``_execute_model()`` returns early it hands back a ready-made
    output, which is passed through unchanged.

    Args:
        scheduler_output: The scheduling decision for this step.
        intermediate_tensors: Accepted for interface compatibility; not
            used by this method.

    Returns:
        ``None`` when a follow-up ``sample_tokens()`` call is required,
        otherwise the output produced directly by ``_execute_model()``.

    Raises:
        RuntimeError: If state from a previous call is still pending,
            i.e. ``sample_tokens()`` was not called in between.
    """
    pending_state = self.execute_model_state
    if pending_state is not None:
        raise RuntimeError("State error: sample_tokens() must be called "
                           "after execute_model() returns None.")
    # _execute_model() yields (attention metadata, output); only the output
    # is part of this method's contract.
    return self._execute_model(scheduler_output)[1]
457+
def sample_tokens(
    self,
    grammar_output: "GrammarOutput | None",
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
    """Sample tokens from the logits cached by ``execute_model()``.

    Consumes ``self.execute_model_state`` (resetting it to ``None``),
    optionally applies the structured-decoding mask to the cached logits,
    and delegates to ``_sample_from_logits()``.

    Args:
        grammar_output: Structured-decoding information for this step, or
            ``None`` when no request needs grammar masking.

    Returns:
        The result of ``_sample_from_logits()``, or
        ``EMPTY_MODEL_RUNNER_OUTPUT`` when there is no cached state.
    """
    state = self.execute_model_state
    if state is None:
        # This can happen in pipeline parallel case.
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Release the cached state before sampling so the next execute_model()
    # call is not rejected by its pending-state check.
    self.execute_model_state = None

    logits = state.logits
    if grammar_output is not None:
        needs_masking, bitmask_padded, arange = (
            self.structured_decoding_manager.
            prepare_structured_decoding_input(logits, grammar_output))
        logits = self.structured_decoding_manager.structured_decode_fn(
            needs_masking,
            bitmask_padded,
            logits,
            arange,
        )

    return self._sample_from_logits(
        state.scheduler_output,
        state.attn_metadata,
        state.input_ids,
        state.hidden_states,
        logits,
        state.aux_hidden_states,
        state.spec_decode_metadata,
        state.kv_connector_output,
    )
436494
437495 def _modify_prev_results (self ):
438496 # If copy to host has not been done, we just wait.
@@ -510,8 +568,7 @@ def _update_placeholder(self, discard_sampled_tokens_req_indices,
510568 def _execute_model (
511569 self ,
512570 scheduler_output : "VllmSchedulerOutput" ,
513- ) -> tuple [AttentionMetadata , ModelRunnerOutput
514- | AsyncTPUModelRunnerOutput ]:
571+ ) -> tuple [AttentionMetadata , ModelRunnerOutput | None ]:
515572 self .persistent_batch_manager .update_states (
516573 scheduler_output , self .get_mrope_input_positions_fn )
517574 if not scheduler_output .total_num_scheduled_tokens :
@@ -529,9 +586,10 @@ def _execute_model(
529586 "Should not schedule a request that does nothing!" )
530587 # raise Exception(
531588 # "Should not schedule a request that does nothing!")
532- return DUMMY_METADATA , EMPTY_MODEL_RUNNER_OUTPUT ,
589+ return DUMMY_METADATA , EMPTY_MODEL_RUNNER_OUTPUT
533590
534- (input_ids , attn_metadata , sampling_metadata , logits_indices ,
591+ # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
592+ (input_ids , attn_metadata , _ , logits_indices ,
535593 spec_decode_metadata ) = self ._prepare_inputs (scheduler_output )
536594
537595 # multi-modal support
@@ -584,51 +642,67 @@ def _execute_model(
584642 hidden_states ,
585643 lora_metadata ,
586644 )
587- if scheduler_output .grammar_bitmask is not None :
588- (
589- require_struct_decoding , grammar_bitmask_padded , arange
590- ) = self .structured_decoding_manager .prepare_structured_decoding_input (
591- logits , scheduler_output )
592- logits = self .structured_decoding_manager .structured_decode_fn (
593- require_struct_decoding ,
594- grammar_bitmask_padded ,
595- logits ,
596- arange ,
597- )
598- tpu_sampling_metadata = sampling_metadata
599- if spec_decode_metadata is None :
600- next_tokens = sample (
601- self .rng_params_for_sampling ,
602- self .mesh ,
603- logits ,
604- tpu_sampling_metadata ,
605- )
606- else :
607- bonus_logits = self ._select_from_array_fn (
608- logits , spec_decode_metadata .bonus_logits_indices )
609- bonus_token_ids = sample (
610- self .rng_params_for_sampling ,
611- self .mesh ,
612- bonus_logits ,
613- tpu_sampling_metadata ,
614- )
615- target_logits = self ._select_from_array_fn (
616- logits , spec_decode_metadata .target_logits_indices )
617- next_tokens = self .rejection_sampler (
618- draft_token_ids = spec_decode_metadata .draft_token_ids ,
619- num_draft_tokens = spec_decode_metadata .draft_lengths ,
620- draft_probs = None ,
621- target_logits = target_logits ,
622- bonus_token_ids = bonus_token_ids ,
623- sampling_metadata = tpu_sampling_metadata ,
624- key = self .rng_params_for_sampling ,
625- )
626645
627- if tpu_sampling_metadata .logprobs :
628- logprobs = self ._compute_and_gather_logprobs (
629- logits , next_tokens , self .model_config .max_logprobs )
630- else :
631- logprobs = None
646+ self .execute_model_state = ExecuteModelState (
647+ scheduler_output = scheduler_output ,
648+ attn_metadata = attn_metadata ,
649+ input_ids = input_ids ,
650+ hidden_states = hidden_states ,
651+ logits = logits ,
652+ aux_hidden_states = aux_hidden_states ,
653+ spec_decode_metadata = spec_decode_metadata ,
654+ kv_connector_output = kv_connector_output ,
655+ )
656+ return attn_metadata , None
657+
658+ def _sample_from_logits (
659+ self ,
660+ scheduler_output : "VllmSchedulerOutput" ,
661+ attn_metadata : AttentionMetadata ,
662+ input_ids : Optional [jax .Array ],
663+ hidden_states : jax .Array ,
664+ logits : jax .Array ,
665+ aux_hidden_states : Optional [jax .Array ],
666+ spec_decode_metadata : Optional [SpecDecodeMetadata ],
667+ kv_connector_output : Optional [KVConnectorOutput ],
668+ ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput :
669+ padded_num_reqs = runner_utils .get_padded_num_reqs_with_upper_limit (
670+ self .input_batch .num_reqs , self .max_num_reqs )
671+ tpu_sampling_metadata = TPUSupportedSamplingMetadata .from_input_batch (
672+ self .mesh , self .input_batch , padded_num_reqs )
673+ if spec_decode_metadata is None :
674+ next_tokens = sample (
675+ self .rng_params_for_sampling ,
676+ self .mesh ,
677+ logits ,
678+ tpu_sampling_metadata ,
679+ )
680+ else :
681+ bonus_logits = self ._select_from_array_fn (
682+ logits , spec_decode_metadata .bonus_logits_indices )
683+ bonus_token_ids = sample (
684+ self .rng_params_for_sampling ,
685+ self .mesh ,
686+ bonus_logits ,
687+ tpu_sampling_metadata ,
688+ )
689+ target_logits = self ._select_from_array_fn (
690+ logits , spec_decode_metadata .target_logits_indices )
691+ next_tokens = self .rejection_sampler (
692+ draft_token_ids = spec_decode_metadata .draft_token_ids ,
693+ num_draft_tokens = spec_decode_metadata .draft_lengths ,
694+ draft_probs = None ,
695+ target_logits = target_logits ,
696+ bonus_token_ids = bonus_token_ids ,
697+ sampling_metadata = tpu_sampling_metadata ,
698+ key = self .rng_params_for_sampling ,
699+ )
700+
701+ if tpu_sampling_metadata .logprobs :
702+ logprobs = self ._compute_and_gather_logprobs (
703+ logits , next_tokens , self .model_config .max_logprobs )
704+ else :
705+ logprobs = None
632706
633707 num_reqs = self .input_batch .num_reqs
634708
@@ -707,7 +781,7 @@ def _execute_model(
707781 async_model_runner_output = AsyncTPUModelRunnerOutput (
708782 model_runner_output , next_tokens , num_reqs ,
709783 discard_sampled_tokens_req_indices )
710- return attn_metadata , async_model_runner_output
784+ return async_model_runner_output
711785
712786 if spec_decode_metadata is None :
713787 next_tokens = np .asarray (jax .device_get (next_tokens ))
@@ -766,7 +840,7 @@ def _execute_model(
766840 pooler_output = [],
767841 kv_connector_output = kv_connector_output ,
768842 )
769- return attn_metadata , model_runner_output
843+ return model_runner_output
770844
771845 @functools .partial (jax .jit , static_argnums = (0 , ))
772846 def _select_from_array_fn (self , array , indices_to_select ):
0 commit comments