Merge pull request #10 from huggingface/refactor-configuration
Refactor configuration
qubvel authored Sep 16, 2024
2 parents f51ccec + 41a4d8a commit 0350540
Showing 13 changed files with 324 additions and 392 deletions.
8 changes: 2 additions & 6 deletions src/transformers/__init__.py
@@ -2690,12 +2690,7 @@
_import_structure["models.mllama"].extend(
[
"MllamaForConditionalGeneration",
"MllamaPreTrainedModel",
]
)
_import_structure["models.mllama"].extend(
[
"MllamaForConditionalGeneration",
"MllamaForCausalLM",
"MllamaPreTrainedModel",
]
)
@@ -7265,6 +7260,7 @@
)
from .models.mllama import (
MllamaForConditionalGeneration,
MllamaForCausalLM,
MllamaPreTrainedModel,
)
from .models.mobilebert import (
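Taken together, the two hunks above collapse the duplicated extend() call and register MllamaForCausalLM both in the lazy _import_structure and in the direct imports further down, so the class becomes importable from the package root. A minimal usage sketch, assuming a build that includes this commit; the checkpoint path is a placeholder, not taken from this PR:

from transformers import MllamaForCausalLM, MllamaForConditionalGeneration

# Placeholder path: any checkpoint produced by the conversion script shown below.
model = MllamaForCausalLM.from_pretrained("path/to/mllama-checkpoint")
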
3 changes: 1 addition & 2 deletions src/transformers/models/auto/modeling_auto.py
@@ -323,7 +323,6 @@
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForPreTraining"),
("mllama", "MllamaForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("mpt", "MptForCausalLM"),
@@ -496,6 +495,7 @@
("megatron-bert", "MegatronBertForCausalLM"),
("mistral", "MistralForCausalLM"),
("mixtral", "MixtralForCausalLM"),
("mllama", "MllamaForCausalLM"),
("mpt", "MptForCausalLM"),
("musicgen", "MusicgenForCausalLM"),
("musicgen_melody", "MusicgenMelodyForCausalLM"),
@@ -734,7 +734,6 @@
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
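Besides dropping the duplicated MllamaForConditionalGeneration entries, these hunks add an MllamaForCausalLM entry to the causal-LM mapping. A hedged sketch of what that entry is meant to enable, assuming a checkpoint whose config reports model_type "mllama"; the path is a placeholder:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/mllama-checkpoint")
print(type(model).__name__)  # expected to resolve to MllamaForCausalLM via the new mapping entry
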
2 changes: 2 additions & 0 deletions src/transformers/models/mllama/__init__.py
@@ -35,6 +35,7 @@
else:
_import_structure["modeling_mllama"] = [
"MllamaForConditionalGeneration",
"MllamaForCausalLM",
"MllamaPreTrainedModel",
]

@@ -59,6 +60,7 @@
else:
from .modeling_mllama import (
MllamaForConditionalGeneration,
MllamaForCausalLM,
MllamaPreTrainedModel,
)

17 changes: 4 additions & 13 deletions src/transformers/models/mllama/configuration_mllama.py
@@ -87,41 +87,32 @@ def __init__(
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_global_layers=8,
vision_chunk_size=448,
projection_dim=4096,
vision_input_dim=1280,
vision_output_dim=7680,
return_intermediate=None,
intermediate_layers_indices=[3, 7, 15, 23, 30],
max_num_tiles=4, # same as vision max num chunks? yes ;-)
norm_eps=1.0e-5,
in_channels=3,
supported_aspect_ratios=None,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.num_channels = num_channels
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.vision_output_dim = vision_output_dim
self.vision_chunk_size = vision_chunk_size
self.patch_size = patch_size
self.projection_dim = projection_dim
self.vision_input_dim = vision_input_dim
if return_intermediate is None:
return_intermediate = [3, 7, 15, 23, 30]
self.return_intermediate = return_intermediate
self.intermediate_layers_indices = intermediate_layers_indices
self.num_global_layers = num_global_layers
self.max_num_tiles = max_num_tiles
self.norm_eps = norm_eps
self.in_channels = in_channels

self.hidden_size = vision_input_dim
self.attention_heads = num_attention_heads
self.intermediate_size = 4 * vision_input_dim
self.hidden_act = hidden_act
self.supported_aspect_ratios = supported_aspect_ratios

@property
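For reference, a hedged sketch of the refactored constructor surface: vision_chunk_size, vision_input_dim, projection_dim and return_intermediate are gone, image_size and intermediate_layers_indices take their place, and **kwargs is now forwarded to the parent config. The values below mirror what the conversion script further down passes for the 11B checkpoint and are illustrative, not authoritative defaults:

from transformers.models.mllama.configuration_mllama import MllamaVisionConfig

vision_config = MllamaVisionConfig(
    hidden_size=1280,                                # was vision_input_dim
    intermediate_size=4 * 1280,
    num_hidden_layers=32,
    num_attention_heads=16,                          # assumed n_heads_vision for the 11B model
    num_global_layers=8,
    intermediate_layers_indices=[3, 7, 15, 23, 30],  # was return_intermediate
    image_size=448,                                  # was vision_chunk_size
    max_num_tiles=4,
)
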
24 changes: 16 additions & 8 deletions src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -78,10 +78,14 @@
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_1": r"vision_model.\1.layers.\2.input_layernorm",
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_2": r"vision_model.\1.layers.\2.post_attention_layernorm",
r"vision_model.vision_encoder.global_transformer.resblocks.(\d+).(gate_ffn|gate_attn)": r"vision_model.global_transformer.layers.\1.\2",
r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.ln_\1.\2',
r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.layernorm_\1.\2',
r'vision_model.vision_encoder.positional_embedding\b': r'vision_model.gated_positional_embedding.embedding',
r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding',
r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding.weight',
r'vision_model.vision_encoder.gated_positional_embedding_gate': r'vision_model.gated_positional_embedding.gate',
r"vision_model.vision_encoder.pre_tile_pos_embed.embedding": r"vision_model.pre_tile_positional_embedding.embedding.weight",
r"vision_model.vision_encoder.post_tile_pos_embed.embedding": r"vision_model.post_tile_positional_embedding.embedding.weight",
r"vision_model.vision_encoder.pre_tile_pos_embed.gate": r"vision_model.pre_tile_positional_embedding.gate",
r"vision_model.vision_encoder.post_tile_pos_embed.gate": r"vision_model.post_tile_positional_embedding.gate",
r"vision_model.vision_encoder.(?=\w)": r"vision_model.",
}
# fmt: on
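
A hedged sketch of how a regex mapping table like the one above is typically applied to checkpoint keys; the helper and the two-rule mapping are illustrative, not the script's actual code:

import re

KEY_MAPPING = {
    r"vision_model.vision_encoder.ln_(pre|post).(weight|bias)": r"vision_model.vision_encoder.layernorm_\1.\2",
    r"vision_model.vision_encoder.(?=\w)": r"vision_model.",
}

def rename_key(old_key: str) -> str:
    # Apply each pattern in order; later patterns see earlier replacements.
    for pattern, replacement in KEY_MAPPING.items():
        old_key = re.sub(pattern, replacement, old_key)
    return old_key

print(rename_key("vision_model.vision_encoder.ln_pre.weight"))
# -> "vision_model.layernorm_pre.weight" (the lookahead rule strips the extra prefix)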
@@ -159,6 +163,7 @@ def pre_compute_positional_embedding(embedding):
aspect_ratio_id = i + 1 # we keep 0 index for padding
current_embedding = embedding[:height, :width].reshape(height * width, num_patches, hidden_size)
precomputed_embeddings[aspect_ratio_id, : height * width] = current_embedding
precomputed_embeddings = precomputed_embeddings.flatten(1)
return precomputed_embeddings
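
A hedged sketch of the shape bookkeeping behind the added flatten(1): one flat row per aspect-ratio id, so the precomputed table can be loaded into an nn.Embedding and reshaped back inside the model. Sizes are deliberately tiny and illustrative:

import torch

max_num_tiles, num_patches, hidden_size = 4, 5, 8   # tiny illustrative sizes
num_aspect_ratios = 8                                # e.g. all tilings of at most 4 tiles

precomputed = torch.zeros(num_aspect_ratios + 1, max_num_tiles, num_patches, hidden_size)
flattened = precomputed.flatten(1)                   # (9, 4 * 5 * 8): one flat row per id

table = torch.nn.Embedding.from_pretrained(flattened, freeze=True)
row = table(torch.tensor([3]))                       # hypothetical aspect-ratio id
print(row.view(1, max_num_tiles, num_patches, hidden_size).shape)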


@@ -230,6 +235,7 @@ def write_model(
num_channels = 3
# intermediate size: 28672 for 90B, 5120 for 11B
intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
intermediate_layers_indices = [3, 7, 15, 23, 30] # TODO: Check for 90B model

# vision model
n_layers_vision = 32 # constant
@@ -338,7 +344,9 @@ def write_model(
elif new_key.endswith("gate"):
state_dict[new_key] = current_parameter[0].view(1)

elif "tile_pos_embed.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key:
elif (
"tile_positional_embedding.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key
):
# pre-compute the embeddings
state_dict[new_key] = pre_compute_positional_embedding(current_parameter)

@@ -360,20 +368,20 @@
# Write configs
config_parameters = {CONFIG_KEY_MAPPING[key]: params[key] for key in CONFIG_KEY_MAPPING.keys()}
vision_config = MllamaVisionConfig(
hidden_size=dim_vision, # Constant, taken directly from your notes
intermediate_size=dim_vision * 4,
num_hidden_layers=n_layers_vision,
vision_input_dim=dim_vision, # Constant, taken directly from your notes
return_intermediate=[3, 7, 15, 23, 30], # Based on return_intermediate indices
num_global_layers=n_layers_vision_global,
vision_chunk_size=params["vision_chunk_size"],
num_attention_heads=n_heads_vision,
num_global_layers=n_layers_vision_global,
intermediate_layers_indices=intermediate_layers_indices, # Based on return_intermediate indices
image_size=params["vision_chunk_size"],
max_num_tiles=4,
supported_aspect_ratios=get_all_supported_aspect_ratios(4),
)
text_config = MllamaTextConfig(
**config_parameters,
num_hidden_layers=len(cross_layer_shift) + n_layers,
cross_attention_layers=cross_layer_shift,
vision_input_dim=dim_vision, # Constant, aligned with vision config
attention_bias=False, # Constant set to False
tie_word_embeddings=False, # Constant set to False
intermediate_size=intermediate_size,
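The supported_aspect_ratios argument above comes from get_all_supported_aspect_ratios(4). A hedged re-implementation of what that helper is expected to enumerate; illustrative only, not the library code:

def all_supported_aspect_ratios(max_num_tiles: int):
    # Every (num_tiles_h, num_tiles_w) arrangement using at most max_num_tiles tiles.
    return [
        (height, width)
        for height in range(1, max_num_tiles + 1)
        for width in range(1, max_num_tiles + 1)
        if height * width <= max_num_tiles
    ]

print(all_supported_aspect_ratios(4))
# [(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]
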
130 changes: 0 additions & 130 deletions src/transformers/models/mllama/dummy_convert.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/transformers/models/mllama/image_processing_mllama.py
@@ -343,8 +343,6 @@ def pack_aspect_ratios(aspect_ratios: List[List[Tuple[int, int]]], pad_value: int
The aspect ratios stacked into a numpy array with shape (batch_size, max_num_images, 2).
"""
batch_size = len(aspect_ratios)

# TODO: in original code there is also max_images = max(max_images, 1)
max_num_images = max([len(row) for row in aspect_ratios])

aspect_ratios_stacked = np.full((batch_size, max_num_images, 2), pad_value, dtype=np.int64)
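A hedged sketch of the padding behaviour of pack_aspect_ratios after this cleanup: per-image (num_tiles_h, num_tiles_w) pairs are stacked into a (batch_size, max_num_images, 2) array and missing slots keep pad_value. Inputs and pad_value are illustrative:

import numpy as np

aspect_ratios = [[(1, 2), (2, 2)], [(4, 1)]]  # two samples with a ragged number of images
pad_value = 1

batch_size = len(aspect_ratios)
max_num_images = max(len(row) for row in aspect_ratios)

packed = np.full((batch_size, max_num_images, 2), pad_value, dtype=np.int64)
for i, row in enumerate(aspect_ratios):
    if row:
        packed[i, : len(row)] = np.array(row)

print(packed.shape)  # (2, 2, 2); the second sample's missing slot stays at pad_value
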
