diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 12b7ce18fbc2..d3995b619d31 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,7 +2,7 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
 import torch
 import torch.distributed
@@ -149,6 +149,7 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        # self.intermediate_tensors  # Set after load_model
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -869,7 +870,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> Union[ModelRunnerOutput, torch.Tensor]:
         batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
@@ -919,6 +920,14 @@ def execute_model(
         else:
             positions = self.positions[:num_input_tokens]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
@@ -931,7 +940,9 @@ def execute_model(
                 inputs_embeds=inputs_embeds,
             )
         if not get_pp_group().is_last_rank:
+            # For mid-pipeline stages, return the hidden states.
             return hidden_states
+
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1118,12 +1129,21 @@ def _dummy_run(
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=num_tokens,
-                dtype=self.model_config.dtype,
-                device=self.device)
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            if not hasattr(self, "intermediate_tensors"):
+                self.intermediate_tensors = (
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=self.max_num_tokens,
+                        dtype=self.model_config.dtype,
+                        device=self.device))
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
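
For reviewers, here is the buffer-slicing pattern the diff applies in both `execute_model` and `_dummy_run`, reduced to a minimal sketch. The `IntermediateTensors` stand-in, the tensor names, and the sizes below are illustrative assumptions, not vLLM's actual definitions; the point is that non-first pipeline-parallel ranks allocate their receive buffers once at `max_num_tokens` and hand the model per-step slices, so the buffer addresses stay fixed across steps.

```python
import torch


class IntermediateTensors:
    # Illustrative stand-in (NOT vLLM's class): a thin wrapper over
    # name -> tensor mappings passed between pipeline stages.
    def __init__(self, tensors):
        self.tensors = tensors

    def items(self):
        return self.tensors.items()


MAX_NUM_TOKENS = 1024  # assumed stand-in for self.max_num_tokens
HIDDEN_SIZE = 512      # assumed model hidden size

# Allocated once, mirroring what make_empty_intermediate_tensors()
# produces on non-first PP ranks in the diff's _dummy_run path.
persistent = IntermediateTensors({
    "hidden_states": torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE),
    "residual": torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE),
})


def step_inputs(num_input_tokens: int) -> IntermediateTensors:
    # Slicing returns views, not copies: every step reuses the same
    # storage, which is what CUDA graph capture/replay requires.
    return IntermediateTensors(
        {k: v[:num_input_tokens] for k, v in persistent.items()})


inputs = step_inputs(256)
# Same underlying storage as the persistent buffer.
assert (inputs.tensors["hidden_states"].data_ptr()
        == persistent.tensors["hidden_states"].data_ptr())
```

The same reasoning explains the widened `Union[ModelRunnerOutput, torch.Tensor]` return type: non-last ranks exit `execute_model` early with the raw hidden states (still sized to `num_input_tokens`, including any CUDA-graph padding), and only the last rank trims to `num_scheduled_tokens` and computes logits.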