
Commit 39ddd6e

last fixes
1 parent 587c95b commit 39ddd6e

4 files changed: +39 -2 lines changed

src/transformers/models/mllama/modeling_mllama.py (3 additions, 0 deletions)

@@ -1081,6 +1081,9 @@ def _init_weights(self, module):
         elif isinstance(module, MllamaCrossAttentionDecoderLayer):
             module.cross_attn_attn_gate.data.zero_()
             module.cross_attn_mlp_gate.data.zero_()
+        elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
+            if module.is_gated:
+                module.gate.data.zero_()

     # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
     def _update_causal_mask(
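Note (not part of the diff): zero-initializing a gate parameter is what makes a tanh-gated branch a no-op at the start of training. A minimal sketch of that pattern, using a hypothetical GatedResidualBranch module rather than the real Mllama layers (the exact gating inside MllamaPrecomputedAspectRatioEmbedding may differ):

import torch
from torch import nn

class GatedResidualBranch(nn.Module):
    """Toy gated branch: output = x + tanh(gate) * f(x).

    With the gate initialized to 0, tanh(0) == 0, so the branch contributes
    nothing at initialization and the block behaves as the identity.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Parameter(torch.zeros(1))  # mirrors module.gate.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.gate.tanh() * self.proj(x)

x = torch.randn(2, 8)
block = GatedResidualBranch(8)
assert torch.allclose(block(x), x)  # identity at init because the gate is zero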

src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py (3 additions, 0 deletions)

@@ -1738,6 +1738,9 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, Phi4MultimodalRMSNorm):
             module.weight.data.fill_(1.0)
+        elif isinstance(module, Phi4MultimodalImageEmbedding):
+            module.global_img_feature_extensor.data.zero_()
+            module.sub_img_feature_extensor.data.zero_()


 PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""

src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py (31 additions, 2 deletions)

@@ -40,7 +40,14 @@
     replace_return_docstrings,
 )
 from ..phi3.configuration_phi3 import Phi3Config
-from ..phi3.modeling_phi3 import Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
+from ..phi3.modeling_phi3 import (
+    Phi3DecoderLayer,
+    Phi3ForCausalLM,
+    Phi3Model,
+    Phi3PreTrainedModel,
+    Phi3RMSNorm,
+    Phi3RotaryEmbedding,
+)
 from ..siglip.configuration_siglip import SiglipVisionConfig
 from ..siglip.modeling_siglip import (
     SiglipEncoder,

@@ -1522,6 +1529,28 @@ def forward(
 """


+class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding):
+    pass
+
+
+class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel):
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Phi4MultimodalRMSNorm):
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, Phi4MultimodalImageEmbedding):
+            module.global_img_feature_extensor.data.zero_()
+            module.sub_img_feature_extensor.data.zero_()
+
+
 class Phi4MultimodalModel(Phi3Model, nn.Module):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalMMDecoderLayer`]

@@ -1832,7 +1861,7 @@ def prepare_inputs_for_generation(
     "Phi4MultimodalAudioModel",
     "Phi4MultimodalVisionPreTrainedModel",
     "Phi4MultimodalVisionModel",
-    "Phi4MultimodalPreTrainedModel",  # noqa
+    "Phi4MultimodalPreTrainedModel",
     "Phi4MultimodalModel",
     "Phi4MultimodalForCausalLM",
     "Phi4MultimodalVisionConfig",

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (2 additions, 0 deletions)

@@ -807,6 +807,8 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Qwen2MoeRMSNorm):
+            module.weight.data.fill_(1.0)


 QWEN2MOE_INPUTS_DOCSTRING = r"""
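Note (for reference, not part of the change): filling an RMSNorm weight with 1.0 means the layer starts out as plain RMS normalization with no learned rescaling. A rough sketch of an RMSNorm-style layer illustrating this; it is not the exact Qwen2MoeRMSNorm code, and it skips the dtype handling the real implementation performs:

import torch
from torch import nn

class RMSNormSketch(nn.Module):
    """Minimal RMSNorm-style layer: y = weight * x / rms(x)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # matches fill_(1.0)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

x = torch.randn(2, 3, 16)
norm = RMSNormSketch(16)
plain_rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), plain_rms)  # weight of ones => pure RMS scaling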
