@@ -143,9 +143,6 @@ def forward(
143143 image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
144144 The temporal, height and width of feature shape of each image in LLM.
145145 """
146-         if pixel_values is not None:
147-             pixel_values = pixel_values.to(dtype=self.dtype)  # (batch_size, max_num_patches, pixel_values)
148-
149146 # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
150147 if pixel_values is not None and image_grid_thw is not None :
151148 # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
@@ -182,9 +179,6 @@ def forward(
182179 image_embeds = image_embeds .to (inputs_embeds .device , inputs_embeds .dtype )
183180 inputs_embeds = inputs_embeds .masked_scatter (image_mask , image_embeds )
184181
185- if attention_mask is not None :
186- attention_mask = attention_mask .to (inputs_embeds .device )
187-
188182 vlm_output = self .vlm .model (
189183 input_ids = None ,
190184 position_ids = position_ids ,
@@ -201,7 +195,8 @@ def forward(
201195 vlm_hidden_states = vlm_output .hidden_states if output_hidden_states else None
202196
203197 last_hidden_states = vlm_output [0 ] # (batch_size, sequence_length, hidden_size)
204- embeddings = self .embedding_proj_layer (last_hidden_states ) # (batch_size, sequence_length, dim)
198+         proj_dtype = self.embedding_proj_layer.weight.dtype
199+         embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)
205200
206201 # L2 normalization
207202 embeddings = embeddings / embeddings .norm (dim = - 1 , keepdim = True ) # (batch_size, sequence_length, dim)