@@ -2,7 +2,7 @@
 
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
 import torch
@@ -149,6 +149,7 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        # self.intermediate_tensors  # Set after load_model
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -869,7 +870,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> Union[ModelRunnerOutput, torch.Tensor]:
         batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
@@ -919,6 +920,14 @@ def execute_model(
         else:
             positions = self.positions[:num_input_tokens]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
@@ -931,7 +940,9 @@ def execute_model(
                 inputs_embeds=inputs_embeds,
             )
         if not get_pp_group().is_last_rank:
+            # For mid-pipeline stages, return the hidden states.
             return hidden_states
+
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
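With the signature change above, `execute_model` now returns a `torch.Tensor` of hidden states on every pipeline-parallel rank except the last, and a `ModelRunnerOutput` only on the last rank. A minimal sketch of the type dispatch a caller needs; the `ModelRunnerOutput` stub below is a stand-in for the real vLLM class, used here for illustration only:

```python
from typing import Union

import torch


class ModelRunnerOutput:
    """Stand-in for vLLM's real ModelRunnerOutput, for illustration only."""


def handle(output: Union[ModelRunnerOutput, torch.Tensor]) -> None:
    # Mid-pipeline ranks hand back raw hidden states; the last rank hands
    # back sampler output, so callers must branch on the returned type.
    if isinstance(output, torch.Tensor):
        print("forward hidden states to the next stage:", tuple(output.shape))
    else:
        print("last rank produced a ModelRunnerOutput")


handle(torch.zeros(4, 8))    # what a mid-pipeline rank would return
handle(ModelRunnerOutput())  # what the last rank would return
```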
@@ -1118,12 +1129,21 @@ def _dummy_run(
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=num_tokens,
-                dtype=self.model_config.dtype,
-                device=self.device)
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            if not hasattr(self, "intermediate_tensors"):
+                self.intermediate_tensors = (
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=self.max_num_tokens,
+                        dtype=self.model_config.dtype,
+                        device=self.device))
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
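Both `execute_model` and `_dummy_run` now slice a single `self.intermediate_tensors` buffer that is allocated once at `self.max_num_tokens`, instead of re-allocating per batch. A small self-contained sketch of why this fits the CUDA-graph persistent-buffer pattern mentioned in the surrounding comments; the buffer name and sizes here are illustrative, not taken from the diff:

```python
import torch

# Illustrative persistent buffer, sized once for the maximum token count.
max_num_tokens, hidden_size = 8192, 4096
buf = torch.zeros(max_num_tokens, hidden_size)

# Per-step slicing for the actual batch size.
num_input_tokens = 1024
view = buf[:num_input_tokens]

# The slice is a view, not a copy: it shares storage with (and starts at the
# same address as) the full buffer, so pointers captured by a CUDA graph
# remain valid across steps with different token counts.
assert view.data_ptr() == buf.data_ptr()
assert view.shape == (num_input_tokens, hidden_size)
```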