diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 17d080fa5a28..04415dace4d9 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -345,7 +345,8 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 06f7cb08a7c8..56e53ac2b815 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -70,7 +70,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
@@ -133,7 +133,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index ffbb9d75a06b..0b18e4a8fe2f 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -117,7 +117,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
         if (hidden_states.shape[-1] != input_embeds.shape[-1]):
             hidden_states = self.fc(hidden_states)
@@ -194,7 +194,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)
 
     def compute_logits(