Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
embeddings = patch_embeds.flatten(2).transpose(1, 2)

max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
Expand All @@ -158,9 +162,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids

position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings

Expand Down
11 changes: 7 additions & 4 deletions src/transformers/models/idefics3/modeling_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
embeddings = patch_embeds.flatten(2).transpose(1, 2)

max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
Expand All @@ -157,9 +161,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids

position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings

Expand Down
11 changes: 7 additions & 4 deletions src/transformers/models/smolvlm/modeling_smolvlm.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_export_smolvlm_vision_encoder does not seem to have been triggered

Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
embeddings = patch_embeds.flatten(2).transpose(1, 2)

max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
Expand All @@ -152,9 +156,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids

position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings

Expand Down