@@ -61,11 +61,13 @@ def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         token_type_ids = _decode_token_type_ids(input_ids)
 
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -358,11 +360,12 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
         return self.encoder(hidden_states)
 
     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
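
The change moves the "use caller-supplied embeddings if given, otherwise look them up" decision into the embeddings module itself, so position and token-type embeddings are always added there. Below is a minimal, self-contained toy sketch of that pattern; ToyEmbeddings, the sizes, and the omitted token-type handling are placeholders for illustration, not the vLLM module from the diff.

import torch
import torch.nn as nn
from typing import Optional


class ToyEmbeddings(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8,
                 max_positions: int = 16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Only look up word embeddings when the caller did not supply them;
        # position embeddings are added in both cases.
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        return inputs_embeds + self.position_embeddings(position_ids)


emb = ToyEmbeddings()
ids = torch.tensor([[1, 2, 3]])
pos = torch.arange(3).unsqueeze(0)

from_ids = emb(ids, pos)                                     # default lookup path
from_precomputed = emb(ids, pos, emb.word_embeddings(ids))   # pass-through path
assert torch.allclose(from_ids, from_precomputed)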