diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index c984845204c4..2ec3edc5a0a7 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -61,11 +61,13 @@ def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -358,11 +360,12 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
         return self.encoder(hidden_states)
 
     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 53e698c4fa80..a13042a6367c 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -56,11 +56,13 @@ def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
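
The diff moves the inputs_embeds fallback out of BertModel.forward and into the embedding layers themselves, so both the BERT and RoBERTa paths skip the word_embeddings lookup only when precomputed embeddings are supplied. Below is a minimal standalone sketch of that control flow, not the vLLM classes themselves: ToyEmbedding, its sizes, and the tensor names are illustrative stand-ins.

import torch
import torch.nn as nn
from typing import Optional


class ToyEmbedding(nn.Module):
    """Illustrative stand-in for the refactored embedding forward():
    the token-ID lookup runs only when no precomputed embeddings are given."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8,
                 max_positions: int = 16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Same branch as the diff: fall back to the ID lookup only when
        # the caller did not pass precomputed embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        return inputs_embeds + self.position_embeddings(position_ids)


# Both call styles now route through the same forward():
emb = ToyEmbedding()
ids = torch.tensor([1, 2, 3])
pos = torch.arange(3)

from_ids = emb(ids, pos)                                      # lookup path
from_embeds = emb(ids, pos, inputs_embeds=torch.randn(3, 8))  # precomputed path

With this shape, the model-level forward (second hunk) no longer needs its own if/else on inputs_embeds; it simply forwards whatever it received to the embedding layer.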