Commit cccef4b

Fix dtype in Paligemma (#40912)
* fix dtypes
* fix copies
* delete unused attr
1 parent beb09cb commit cccef4b

6 files changed, +17 −17 lines changed


src/transformers/models/colpali/modeling_colpali.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@ def forward(
         vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None

         last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
-        embeddings = self.embedding_proj_layer(last_hidden_states)  # (batch_size, sequence_length, dim)
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
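
In words: the VLM backbone may run in a different floating-point dtype than the retrieval head, so the hidden states are cast to the projection layer's own weight dtype right before the projection. A minimal standalone sketch of that pattern (toy shapes and dtypes, not the actual ColPali module):

import torch
from torch import nn

# Hypothetical stand-in for self.embedding_proj_layer; sizes are made up for the example.
proj = nn.Linear(768, 128, dtype=torch.float32)
last_hidden_states = torch.randn(2, 5, 768, dtype=torch.bfloat16)  # e.g. backbone output in bf16

proj_dtype = proj.weight.dtype                        # follow the head's parameters, not the model-wide dtype
embeddings = proj(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)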

src/transformers/models/colqwen2/modeling_colqwen2.py

Lines changed: 2 additions & 7 deletions
@@ -143,9 +143,6 @@ def forward(
             image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                 The temporal, height and width of feature shape of each image in LLM.
         """
-        if pixel_values is not None:
-            pixel_values = pixel_values.to(dtype=self.dtype)  # (batch_size, max_num_patches, pixel_values)
-
         # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
         if pixel_values is not None and image_grid_thw is not None:
             # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
@@ -182,9 +179,6 @@ def forward(
             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
         vlm_output = self.vlm.model(
             input_ids=None,
             position_ids=position_ids,
@@ -201,7 +195,8 @@ def forward(
         vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

         last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
-        embeddings = self.embedding_proj_layer(last_hidden_states)  # (batch_size, sequence_length, dim)
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
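
Alongside the projection fix, the manual pixel_values and attention_mask casts are dropped; the retained `image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)` line still aligns the vision features with the text embeddings before `masked_scatter`. A toy sketch of that retained pattern (made-up shapes, not the real ColQwen2 tensors):

import torch

inputs_embeds = torch.zeros(1, 6, 4, dtype=torch.bfloat16)         # (batch, seq, hidden)
image_mask = torch.tensor([[0, 1, 1, 1, 0, 0]], dtype=torch.bool)  # image-token positions
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
image_embeds = torch.randn(3, 4)                                   # 3 image tokens, fp32

# One call moves and casts the image features, then they are written into the masked slots.
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)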

src/transformers/models/colqwen2/modular_colqwen2.py

Lines changed: 2 additions & 7 deletions
@@ -336,9 +336,6 @@ def forward(
             image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                 The temporal, height and width of feature shape of each image in LLM.
         """
-        if pixel_values is not None:
-            pixel_values = pixel_values.to(dtype=self.dtype)  # (batch_size, max_num_patches, pixel_values)
-
         # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
         if pixel_values is not None and image_grid_thw is not None:
             # NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
@@ -375,9 +372,6 @@ def forward(
             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
         vlm_output = self.vlm.model(
             input_ids=None,
             position_ids=position_ids,
@@ -394,7 +388,8 @@ def forward(
         vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

         last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
-        embeddings = self.embedding_proj_layer(last_hidden_states)  # (batch_size, sequence_length, dim)
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))  # (batch_size, sequence_length, dim)

         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 4 additions & 0 deletions
@@ -756,6 +756,10 @@ class Gemma3Model(PaliGemmaModel):
     # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
     accepts_loss_kwargs = False

+    def __init__(self, config: Gemma3Config):
+        super().__init__(config)
+        del self.text_config_dtype
+
     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
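
Per the commit message ("delete unused attr"), `text_config_dtype` is not needed in this subclass, so the modular override drops it right after the parent constructor sets it. A toy illustration of that inheritance pattern (hypothetical classes, not the real Gemma3/PaliGemma code):

import torch

class Parent:
    def __init__(self):
        self.text_config_dtype = torch.float32  # set unconditionally by the parent __init__

class Child(Parent):
    def __init__(self):
        super().__init__()
        del self.text_config_dtype              # the child does not use it, so it is removed

assert not hasattr(Child(), "text_config_dtype")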

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 1 addition & 0 deletions
@@ -2241,6 +2241,7 @@ class Gemma3nModel(PaliGemmaModel):
     def __init__(self, config: Gemma3nConfig):
         super().__init__(config)
         del self.multi_modal_projector  # Replaced by Gemma3nVisionEmbedder
+        del self.text_config_dtype
         self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
         self.audio_tower = AutoModel.from_config(config.audio_config)
         self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)

src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 6 additions & 2 deletions
@@ -143,6 +143,7 @@ def __init__(self, config: PaliGemmaConfig):
         self.language_model = language_model

         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.text_config_dtype = self.config.get_text_config().dtype or self.dtype
         self.post_init()

     # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
@@ -174,7 +175,7 @@ def _update_causal_mask(
             return None
         is_training = is_training if is_training is not None else self.training
         using_static_cache = isinstance(past_key_values, StaticCache)
-        min_dtype = torch.finfo(self.dtype).min
+        min_dtype = torch.finfo(self.text_config_dtype).min
         if input_tensor is None:
             input_tensor = attention_mask

@@ -193,7 +194,10 @@ def _update_causal_mask(
             return attention_mask

         causal_mask = torch.full(
-            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+            (sequence_length, target_length),
+            fill_value=min_dtype,
+            dtype=self.text_config_dtype,
+            device=cache_position.device,
         )
         # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
         if sequence_length != 1:
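
As the diff reads, the additive attention mask should be built in the dtype the text model actually runs in, rather than the composite model's overall dtype: `torch.finfo(dtype).min` differs per dtype, and a fill value taken from a wider dtype would not be representable in a narrower mask. A small sketch with assumed dtypes (`bfloat16` is only an example, not what any given checkpoint uses):

import torch

text_dtype = torch.bfloat16              # stand-in for config.get_text_config().dtype or self.dtype
min_dtype = torch.finfo(text_dtype).min  # most negative value representable in that dtype

sequence_length, target_length = 4, 6
causal_mask = torch.full(
    (sequence_length, target_length),
    fill_value=min_dtype,
    dtype=text_dtype,                    # mask dtype matches the dtype min_dtype was computed for
)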
