Merge pull request #15 from huggingface/embed-interpolation
Interpolate embeddings for 560 size and update integration tests
qubvel authored Sep 25, 2024
2 parents f40ce28 + 6c58488 commit 554ea46
Showing 2 changed files with 54 additions and 8 deletions.
54 changes: 50 additions & 4 deletions src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -21,6 +21,7 @@

import regex as re
import torch
+import torch.nn.functional as F

from transformers import (
    GenerationConfig,
@@ -173,6 +174,38 @@ def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1
    return hidden_dim


+def interpolate_positional_embedding(
+    embeddings: torch.Tensor, vision_tile_size: int, vision_patch_size: int
+) -> torch.Tensor:
+    """
+    Interpolate the pre-trained position embeddings so that the model can be used on
+    higher-resolution images.
+    """
+    cls_embedding, positional_embedding = embeddings[:1], embeddings[1:]
+    total_num_patches, dim = positional_embedding.shape
+
+    # compute current and target number of patches for height and width
+    num_patches = int(round(total_num_patches**0.5))
+    new_num_patches = vision_tile_size // vision_patch_size
+
+    # Check if the number of patches is already the desired size
+    if num_patches == new_num_patches:
+        return embeddings
+
+    positional_embedding = positional_embedding.transpose(0, 1)
+    positional_embedding = positional_embedding.reshape(1, dim, num_patches, num_patches)
+    positional_embedding = F.interpolate(
+        positional_embedding,
+        size=(new_num_patches, new_num_patches),
+        mode="bicubic",
+        align_corners=False,
+    )
+    positional_embedding = positional_embedding.reshape(dim, -1).transpose(0, 1)
+
+    embeddings = torch.cat([cls_embedding, positional_embedding], dim=0)
+    return embeddings
+
+
def write_model(
    model_path,
    input_base_path,
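
A minimal sketch of how the new helper behaves (not part of this commit). The 448-pixel source tile size, the 14-pixel patch size, and the 1280-dim hidden size below are illustrative assumptions; only the 560 target size comes from the commit title.

    # Hypothetical usage of interpolate_positional_embedding defined above (sizes are assumptions).
    import torch

    dim = 1280                                       # assumed vision hidden size
    old_grid = 448 // 14                             # 32 patches per side -> 1024 patch positions
    embeddings = torch.randn(1 + old_grid**2, dim)   # 1 class embedding + 1024 patch position embeddings

    resized = interpolate_positional_embedding(embeddings, vision_tile_size=560, vision_patch_size=14)
    print(resized.shape)                             # torch.Size([1601, 1280]): 1 + (560 // 14) ** 2 positions

The class-token embedding is kept as-is and only the patch grid is resized bicubically, so the output always has 1 + (vision_tile_size // vision_patch_size) ** 2 rows.
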
@@ -364,10 +397,23 @@ def write_model(
        elif new_key.endswith("gate"):
            state_dict[new_key] = current_parameter[0].view(1)

-        elif (
-            "tile_positional_embedding.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key
-        ):
-            # pre-compute the embeddings
+        elif "vision_model.gated_positional_embedding.embedding" in new_key:
+            current_parameter = interpolate_positional_embedding(
+                current_parameter, vision_tile_size, vision_patch_size
+            )
+            state_dict[new_key] = current_parameter
+
+        elif "vision_model.gated_positional_embedding.tile_embedding.weight" in new_key:
+            current_parameter = current_parameter.permute(2, 0, 1, 3).flatten(1)
+            current_parameter = interpolate_positional_embedding(
+                current_parameter, vision_tile_size, vision_patch_size
+            )
+            current_parameter = current_parameter.reshape(
+                -1, vision_max_num_tiles, vision_max_num_tiles, vision_dim
+            ).permute(1, 2, 0, 3)
+            state_dict[new_key] = pre_compute_positional_embedding(current_parameter)
+
+        elif "tile_positional_embedding.embedding" in new_key:
            state_dict[new_key] = pre_compute_positional_embedding(current_parameter)

        elif new_key != "":
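
For the gated_positional_embedding.tile_embedding.weight branch in the hunk above, a short sketch of the shape bookkeeping, again not part of this commit and with assumed sizes (4 tiles per side, 1280-dim embeddings, a 32x32 source grid assumed to be laid out as (tiles, tiles, positions, dim)): the tile axes are folded into the channel dimension so the same 2D interpolation can be reused, then unfolded back before pre_compute_positional_embedding is applied.

    # Hypothetical shapes only; interpolate_positional_embedding is the helper added above.
    import torch

    vision_max_num_tiles, vision_dim = 4, 1280       # assumed values
    tile_embedding = torch.randn(vision_max_num_tiles, vision_max_num_tiles, 1 + 32**2, vision_dim)

    flat = tile_embedding.permute(2, 0, 1, 3).flatten(1)        # (1025, 4 * 4 * 1280)
    resized = interpolate_positional_embedding(flat, 560, 14)   # (1601, 4 * 4 * 1280)
    restored = resized.reshape(-1, vision_max_num_tiles, vision_max_num_tiles, vision_dim).permute(1, 2, 0, 3)
    print(restored.shape)                                       # torch.Size([4, 4, 1601, 1280])
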
8 changes: 4 additions & 4 deletions tests/models/mllama/test_modeling_mllama.py
@@ -407,7 +407,7 @@ def test_11b_model_integration_generate(self):
        output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

        decoded_output = processor.decode(output[0], skip_special_tokens=True)
-        expected_output = "If I had to write a haiku for this one, it would be:.\nLong exposure dock.\nWhistler, British Columbia.\nNikon D800E"  # fmt: skip
+        expected_output = "If I had to write a haiku for this one, it would be:.\nI'm not a poet.\nBut I'm a photographer.\nAnd I'm a"  # fmt: skip

        self.assertEqual(
            decoded_output,
@@ -470,7 +470,7 @@ def test_11b_model_integration_forward(self):
        output = model(**inputs)

        actual_logits = output.logits[0, -1, :5].cpu()
-        expected_logits = torch.tensor([8.5781, 7.6719, 4.6406, 0.7192, 3.0918])
+        expected_logits = torch.tensor([8.3594, 7.7148, 4.7266, 0.7803, 3.1504])
        self.assertTrue(
            torch.allclose(actual_logits, expected_logits, atol=0.1),
            f"Actual logits: {actual_logits}"
@@ -506,7 +506,7 @@ def test_11b_model_integration_batched_generate(self):

        # Check first output
        decoded_output = processor.decode(output[0], skip_special_tokens=True)
-        expected_output = "If I had to write a haiku for this one, it would be:.\nLong exposure dock.\nWhistler, British Columbia.\nNikon D800E"  # fmt: skip
+        expected_output = "If I had to write a haiku for this one, it would be:.\nI'm not a poet.\nBut I'm a photographer.\nAnd I'm a"  # fmt: skip

        self.assertEqual(
            decoded_output,
@@ -516,7 +516,7 @@ def test_11b_model_integration_batched_generate(self):

        # Check second output
        decoded_output = processor.decode(output[1], skip_special_tokens=True)
-        expected_output = "This image shows is a photo of a stop sign in front of a Chinese arch. The stop sign is red and white, and the arch"  # fmt: skip
+        expected_output = "This image shows is a photograph of a stop sign in front of a Chinese archway. The stop sign is red with white letters and is"  # fmt: skip

        self.assertEqual(
            decoded_output,