From b6b59d0d9aa203c2047164745c5726a4a8d2a36d Mon Sep 17 00:00:00 2001
From: cyy
Date: Wed, 27 Aug 2025 21:14:59 +0800
Subject: [PATCH] Add missing arguments

Signed-off-by: cyy
---
 src/transformers/models/bamba/modular_bamba.py | 2 +-
 src/transformers/models/cohere2/modular_cohere2.py | 2 +-
 src/transformers/models/d_fine/modular_d_fine.py | 2 +-
 .../models/data2vec/modular_data2vec_audio.py | 6 +++---
 src/transformers/models/dia/modular_dia.py | 2 +-
 src/transformers/models/dots1/modular_dots1.py | 2 +-
 src/transformers/models/ernie4_5/modular_ernie4_5.py | 2 +-
 .../models/esm/openfold_utils/residue_constants.py | 2 --
 src/transformers/models/evolla/modular_evolla.py | 4 ++--
 src/transformers/models/falcon_h1/modeling_falcon_h1.py | 2 +-
 src/transformers/models/falcon_h1/modular_falcon_h1.py | 6 +++---
 .../models/falcon_mamba/modular_falcon_mamba.py | 2 +-
 src/transformers/models/florence2/modular_florence2.py | 2 +-
 src/transformers/models/gemma/modular_gemma.py | 2 +-
 src/transformers/models/gemma2/modular_gemma2.py | 2 +-
 src/transformers/models/gemma3/modular_gemma3.py | 4 ++--
 src/transformers/models/gemma3n/modular_gemma3n.py | 4 ++--
 src/transformers/models/glm4_moe/modular_glm4_moe.py | 4 ++--
 src/transformers/models/glm4v/modular_glm4v.py | 4 ++--
 src/transformers/models/got_ocr2/modular_got_ocr2.py | 2 +-
 src/transformers/models/helium/modular_helium.py | 2 +-
 .../models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py | 2 +-
 .../models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py | 2 +-
 src/transformers/models/internvl/modular_internvl.py | 2 +-
 src/transformers/models/janus/modular_janus.py | 2 +-
 src/transformers/models/mistral/modular_mistral.py | 2 +-
 .../models/mm_grounding_dino/modular_mm_grounding_dino.py | 4 ++--
 src/transformers/models/ovis2/modular_ovis2.py | 4 ++--
 .../models/phi4_multimodal/modular_phi4_multimodal.py | 8 ++++----
 .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 2 +-
 src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py | 2 +-
 src/transformers/models/siglip2/modular_siglip2.py | 2 +-
 src/transformers/models/unispeech/modular_unispeech.py | 2 +-
 .../models/unispeech_sat/modular_unispeech_sat.py | 4 ++--
 .../models/wav2vec2_bert/modular_wav2vec2_bert.py | 6 +++---
 .../wav2vec2_conformer/modular_wav2vec2_conformer.py | 2 +-
 src/transformers/models/xlstm/modeling_xlstm.py | 2 +-
 37 files changed, 53 insertions(+), 55 deletions(-)

diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
index 03fa58825c50..f0aac49dd1c7 100644
--- a/src/transformers/models/bamba/modular_bamba.py
+++ b/src/transformers/models/bamba/modular_bamba.py
@@ -708,7 +708,7 @@ class BambaRMSNorm(LlamaRMSNorm):
 class BambaDecoderLayer(JambaAttentionDecoderLayer):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
-        super().__init__()
+        super().__init__(config, layer_idx)
         del self.self_attn
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index ea72b603f60f..56a72b102203 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -271,7 +271,7 @@ class Cohere2LayerNorm(CohereLayerNorm):
     pass
 
 
-class Cohere2Attention(CohereAttention, nn.Module):
+class Cohere2Attention(CohereAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py
index 883f07a9979b..52ac7fef7b0d 100644
--- a/src/transformers/models/d_fine/modular_d_fine.py
+++ b/src/transformers/models/d_fine/modular_d_fine.py
@@ -898,7 +898,7 @@ def __init__(self, config: DFineConfig):
 class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
     def __init__(self, config: DFineConfig):
-        DFinePreTrainedModel.__init__(config)
+        DFinePreTrainedModel.__init__(self, config)
         # D-FINE encoder-decoder model
         self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py
index ff93462c5b81..91cb04730e4a 100644
--- a/src/transformers/models/data2vec/modular_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modular_data2vec_audio.py
@@ -112,7 +112,7 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module):
+class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder):
     def __init__(self, config):
         nn.Module.__init__(self)
         self.conv_layers = nn.ModuleList(
@@ -183,7 +183,7 @@ def load_adapter(self):
 class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
     def __init__(self, config: Data2VecAudioConfig):
-        Data2VecAudioPreTrainedModel.__init__(config)
+        Data2VecAudioPreTrainedModel.__init__(self, config)
         self.config = config
         self.feature_extractor = Data2VecAudioFeatureEncoder(config)
         self.feature_projection = Data2VecAudioFeatureProjection(config)
@@ -215,7 +215,7 @@ def forward(self, **super_kwargs):
 class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
     def __init__(self, config):
-        Data2VecAudioPreTrainedModel.__init__(config)
+        Data2VecAudioPreTrainedModel.__init__(self, config)
         self.data2vec_audio = Data2VecAudioModel(config)
         self.dropout = nn.Dropout(config.final_dropout)
diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py
index 267c444d952c..3ee8d15a5873 100644
--- a/src/transformers/models/dia/modular_dia.py
+++ b/src/transformers/models/dia/modular_dia.py
@@ -107,7 +107,7 @@ class DiaRotaryEmbedding(LlamaRotaryEmbedding):
     pass
 
 
-class DiaSelfAttention(LlamaAttention, nn.Module):
+class DiaSelfAttention(LlamaAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
diff --git a/src/transformers/models/dots1/modular_dots1.py b/src/transformers/models/dots1/modular_dots1.py
index 9bd14c8dc5a9..345265a14080 100644
--- a/src/transformers/models/dots1/modular_dots1.py
+++ b/src/transformers/models/dots1/modular_dots1.py
@@ -62,7 +62,7 @@ class Dots1TopkRouter(DeepseekV3TopkRouter):
 class Dots1DecoderLayer(DeepseekV3DecoderLayer):
     def __init__(self, config: Dots1Config, layer_idx: int):
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.attention_type = config.layer_types[layer_idx]
diff --git a/src/transformers/models/ernie4_5/modular_ernie4_5.py b/src/transformers/models/ernie4_5/modular_ernie4_5.py
index f76c7c6bdae7..7cec0232ca68 100644
--- a/src/transformers/models/ernie4_5/modular_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modular_ernie4_5.py
@@ -84,7 +84,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class Ernie4_5MLP(LlamaMLP):
     def __init__(self, config: Ernie4_5Config):
-        super().__init__()
+        super().__init__(config)
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
diff --git a/src/transformers/models/esm/openfold_utils/residue_constants.py b/src/transformers/models/esm/openfold_utils/residue_constants.py
index 9af36d7db74e..e92e65d29bfb 100644
--- a/src/transformers/models/esm/openfold_utils/residue_constants.py
+++ b/src/transformers/models/esm/openfold_utils/residue_constants.py
@@ -541,7 +541,6 @@ def make_bond_key(atom1_name: str, atom2_name: str) -> str:
 # A compact atom encoding with 14 columns
 # pylint: disable=line-too-long
-# pylint: disable=bad-whitespace
 restype_name_to_atom14_names: dict[str, list[str]] = {
     "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
     "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
@@ -566,7 +565,6 @@ def make_bond_key(atom1_name: str, atom2_name: str) -> str:
     "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
 }
 # pylint: enable=line-too-long
-# pylint: enable=bad-whitespace
 
 
 # This is the standard residue order when coding AA type as a number.
diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py
index 630c99a382f9..7b4252477d74 100644
--- a/src/transformers/models/evolla/modular_evolla.py
+++ b/src/transformers/models/evolla/modular_evolla.py
@@ -65,7 +65,7 @@ class EvollaSaProtEmbeddings(EsmEmbeddings):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         # remove the position_ids in EsmEmbeddings
         self.position_ids = None
@@ -127,7 +127,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch
         )
 
 
-class EvollaSaProtSelfAttention(EsmSelfAttention, nn.Module):
+class EvollaSaProtSelfAttention(EsmSelfAttention):
     def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
         nn.Module.__init__(self)
         self.config = config
diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
index 779cfd14bde0..fc90ac622486 100644
--- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -1007,7 +1007,7 @@ def forward(
 class FalconH1MLP(nn.Module):
-    def __init__(self, config: FalconH1Config = None):
+    def __init__(self, config: FalconH1Config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py
index 10aaadcdd1e2..1ff8288f9c4f 100644
--- a/src/transformers/models/falcon_h1/modular_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py
@@ -252,7 +252,7 @@ def forward(
 class FalconH1RMSNormGated(MambaRMSNormGated):
     def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True):
-        super().__init__()
+        super().__init__(hidden_size=hidden_size, eps=eps)
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
         self.n_groups = n_groups
@@ -812,8 +812,8 @@ def forward(
 class FalconH1MLP(LlamaMLP):
-    def __init__(self, config: FalconH1Config = None):
-        super().__init__()
+    def __init__(self, config: FalconH1Config):
+        super().__init__(config)
         self.gate_multiplier, self.down_multiplier = config.mlp_multipliers
 
     def forward(self, x):
diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
index e28a84e2f0c2..090a147d31e2 100644
--- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -517,7 +517,7 @@ class FalconMambaCausalLMOutput(MambaCausalLMOutput):
 class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel):
     def __init__(self, config):
-        FalconMambaPreTrainedModel.__init__(config)
+        FalconMambaPreTrainedModel.__init__(self, config)
         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py
index 73ba9882eae0..417e296071de 100644
--- a/src/transformers/models/florence2/modular_florence2.py
+++ b/src/transformers/models/florence2/modular_florence2.py
@@ -1065,7 +1065,7 @@ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
 class Florence2VisionMLP(Llama4VisionMLP):
     def __init__(self, config: Florence2VisionConfig, stage_idx: int):
-        super().__init__()
+        super().__init__(config)
         self.fc1 = nn.Linear(config.embed_dim[stage_idx], int(config.embed_dim[stage_idx] * config.mlp_ratio))
         self.activation_fn = ACT2FN[config.activation_function]
         self.fc2 = nn.Linear(int(config.embed_dim[stage_idx] * config.mlp_ratio), config.embed_dim[stage_idx])
diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index 67aedbd55115..ea37bd31ef12 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -360,7 +360,7 @@ def extra_repr(self):
 class GemmaMLP(LlamaMLP):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index e0124339852e..7f101ff1ec0a 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -209,7 +209,7 @@ class Gemma2RMSNorm(GemmaRMSNorm):
 class Gemma2MLP(GemmaMLP):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.act_fn = ACT2FN[config.hidden_activation]
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index d507397b0e31..6e06671ea0bb 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -383,7 +383,7 @@ def __init__(self, config: Gemma3TextConfig):
 class Gemma3RMSNorm(Gemma2RMSNorm):
     def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
+        super().__init__(dim=dim, eps=eps)
 
 
 class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
@@ -396,7 +396,7 @@ class Gemma3Attention(Gemma2Attention):
     def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.sliding_window = config.sliding_window if self.is_sliding else None
         self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
index 3a8db73cdd66..9df2d6a05bf1 100644
--- a/src/transformers/models/gemma3n/modular_gemma3n.py
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -1739,7 +1739,7 @@ def apply_rotary_pos_emb(
 class Gemma3nTextAttention(Gemma3Attention):
     def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
-        super().__init__()
+        super().__init__(config, layer_idx)
         del self.attn_logit_softcapping
         del self.scaling
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
@@ -2234,7 +2234,7 @@ class Gemma3nModel(PaliGemmaModel):
     _checkpoint_conversion_mapping = {}
 
     def __init__(self, config: Gemma3nConfig):
-        super().__init__()
+        super().__init__(config)
         del self.multi_modal_projector  # Replaced by Gemma3nVisionEmbedder
         self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
         self.audio_tower = AutoModel.from_config(config.audio_config)
diff --git a/src/transformers/models/glm4_moe/modular_glm4_moe.py b/src/transformers/models/glm4_moe/modular_glm4_moe.py
index cf157ad9b26a..bc07483c7f22 100644
--- a/src/transformers/models/glm4_moe/modular_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modular_glm4_moe.py
@@ -255,7 +255,7 @@ def __init__(
         )
 
 
-class Glm4MoeAttention(CohereAttention, nn.Module):
+class Glm4MoeAttention(CohereAttention):
     def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None):
         nn.Module.__init__(self)
         self.config = config
@@ -287,7 +287,7 @@ class Glm4MoeMLP(DeepseekV3MLP):
     pass
 
 
-class Glm4MoeTopkRouter(DeepseekV3TopkRouter, nn.Module):
+class Glm4MoeTopkRouter(DeepseekV3TopkRouter):
     def __init__(self, config: Glm4MoeConfig):
         nn.Module.__init__(self)
         self.config = config
diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py
index ca17a630a825..e872995aed32 100644
--- a/src/transformers/models/glm4v/modular_glm4v.py
+++ b/src/transformers/models/glm4v/modular_glm4v.py
@@ -507,7 +507,7 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc
 class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
     def __init__(self, config: Glm4vVisionConfig) -> None:
-        super().__init__()
+        super().__init__(config)
         self.attention_dropout = config.attention_dropout
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
@@ -515,7 +515,7 @@ def __init__(self, config: Glm4vVisionConfig) -> None:
 class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
     def __init__(self, config) -> None:
-        super().__init__()
+        super().__init__(config)
         self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attn = Glm4vVisionAttention(config)
diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py
index cf98029307c0..e51e4b12c798 100644
--- a/src/transformers/models/got_ocr2/modular_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -240,7 +240,7 @@ class GotOcr2VisionAttention(SamVisionAttention):
 class GotOcr2VisionLayer(SamVisionLayer):
     def __init__(self, config, window_size):
-        super().__init__()
+        super().__init__(config, window_size)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attn = GotOcr2VisionAttention(config, window_size)
         self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
diff --git a/src/transformers/models/helium/modular_helium.py b/src/transformers/models/helium/modular_helium.py
index 0549c229546d..fe53f7820abb 100644
--- a/src/transformers/models/helium/modular_helium.py
+++ b/src/transformers/models/helium/modular_helium.py
@@ -104,7 +104,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
 class HeliumDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.mlp = HeliumMLP(config)
         self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py
index ada3fbde0986..c79ccc6a616d 100644
--- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py
+++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py
@@ -114,7 +114,7 @@ def forward(
 class HunYuanDenseV1DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.layer_idx = layer_idx
diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
index 12a460b3108d..645c54ae73af 100644
--- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
@@ -187,7 +187,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class HunYuanMoEV1DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.hidden_size = config.hidden_size
         self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx)
         self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx)
diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py
index 1c3d7bec639b..1e0757d6cf0c 100644
--- a/src/transformers/models/internvl/modular_internvl.py
+++ b/src/transformers/models/internvl/modular_internvl.py
@@ -79,7 +79,7 @@ class InternVLVisionRMSNorm(LlamaRMSNorm):
 class InternVLVisionAttention(JanusVisionAttention):
     def __init__(self, config: InternVLVisionConfig):
-        super().__init__()
+        super().__init__(config)
         del self.num_key_value_groups
         # Needed for flash attention
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index 2ae65710b2a5..2b5a4c09e023 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -536,7 +536,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class JanusVisionEncoderLayer(SiglipEncoderLayer):
     def __init__(self, config: JanusVisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         self.embed_dim = config.hidden_size
         self.self_attn = JanusVisionAttention(config)
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py
index 470c552ce647..290d60b91e66 100644
--- a/src/transformers/models/mistral/modular_mistral.py
+++ b/src/transformers/models/mistral/modular_mistral.py
@@ -44,7 +44,7 @@ def __init__(self, config):
 class MistralAttention(LlamaAttention):
     def __init__(self, config: MistralConfig, layer_idx: int):
-        super().__init__()
+        super().__init__(config, layer_idx)
         self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
diff --git a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py
index aea644fdd656..a05045a68cb5 100644
--- a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py
+++ b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py
@@ -337,7 +337,7 @@ class MMGroundingDinoDecoder(GroundingDinoDecoder):
 class MMGroundingDinoModel(GroundingDinoModel, MMGroundingDinoPreTrainedModel):
     def __init__(self, config: MMGroundingDinoConfig):
-        MMGroundingDinoPreTrainedModel.__init__(config)
+        MMGroundingDinoPreTrainedModel.__init__(self, config)
         # Create backbone + positional encoding
         backbone = MMGroundingDinoConvEncoder(config)
@@ -400,7 +400,7 @@ class MMGroundingDinoForObjectDetection(GroundingDinoForObjectDetection, MMGroun
     ]
 
     def __init__(self, config: MMGroundingDinoConfig):
-        MMGroundingDinoPreTrainedModel.__init__(config)
+        MMGroundingDinoPreTrainedModel.__init__(self, config)
         self.model = MMGroundingDinoModel(config)
diff --git a/src/transformers/models/ovis2/modular_ovis2.py b/src/transformers/models/ovis2/modular_ovis2.py
index fee26273d1ed..2eeff19fa852 100644
--- a/src/transformers/models/ovis2/modular_ovis2.py
+++ b/src/transformers/models/ovis2/modular_ovis2.py
@@ -60,7 +60,7 @@ class Ovis2VisionMLP(LlamaMLP):
 class Ovis2VisionEmbeddings(SiglipVisionEmbeddings):
     def __init__(self, config: Ovis2VisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
 
     def interpolate_pos_encoding(self):
@@ -87,7 +87,7 @@ class Ovis2VisionEncoderLayer(Aimv2EncoderLayer):
 class Ovis2VisionEncoder(SiglipEncoder):
     def __init__(self, config: Ovis2VisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
index c6bdc0cc6c34..bfab86560abe 100644
--- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
@@ -528,7 +528,7 @@ def __init__(self, config: Phi4MultimodalVisionConfig):
 class Phi4MultimodalVisionEncoder(SiglipEncoder):
     def __init__(self, config: Phi4MultimodalVisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.layers = nn.ModuleList(
             [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
@@ -582,7 +582,7 @@ def _init_weights(self, module):
             module.weight.data.fill_(1.0)
 
 
-class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings, nn.Module):
+class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings):
     def __init__(self, config: Phi4MultimodalVisionConfig):
         nn.Module.__init__(self)
         self.config = config
@@ -1455,7 +1455,7 @@ def _init_weights(self, module):
             module.sub_img_feature_extensor.data.zero_()
 
 
-class Phi4MultimodalModel(Phi3Model, nn.Module):
+class Phi4MultimodalModel(Phi3Model):
     def __init__(self, config: Phi4MultimodalConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
@@ -1570,7 +1570,7 @@ def forward(
         )
 
 
-class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module):
+class Phi4MultimodalForCausalLM(Phi3ForCausalLM):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 1ee9411cd7aa..6425f6df9129 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2061,7 +2061,7 @@ def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
 # It's same as `Qwen2_5_VLAttention`, but talker model's hidden_size isn't divisible by num_heads.
 # Removes the value error as a workaround.
-class Qwen2_5OmniAttention(Qwen2_5_VLAttention, nn.Module):
+class Qwen2_5OmniAttention(Qwen2_5_VLAttention):
     def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
         nn.Module.__init__(self)
         self.config = config
diff --git a/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
index 9c8477c86e0f..af28015d1462 100644
--- a/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
+++ b/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
@@ -610,7 +610,7 @@ class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
 class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel):
     def __init__(self, config: RTDetrV2Config):
-        RTDetrV2PreTrainedModel.__init__(config)
+        RTDetrV2PreTrainedModel.__init__(self, config)
         # RTDETR encoder-decoder model
         self.model = RTDetrV2Model(config)
diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py
index 803c1b070b3f..2956187763c8 100644
--- a/src/transformers/models/siglip2/modular_siglip2.py
+++ b/src/transformers/models/siglip2/modular_siglip2.py
@@ -232,7 +232,7 @@ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTen
 class Siglip2VisionTransformer(SiglipVisionTransformer):
     def __init__(self, config: Siglip2VisionConfig):
-        super().__init__()
+        super().__init__(config)
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
     # Update: add `spatial_shapes` and `attention_mask`
diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py
index 5ab58ff18748..900079b7bb9b 100644
--- a/src/transformers/models/unispeech/modular_unispeech.py
+++ b/src/transformers/models/unispeech/modular_unispeech.py
@@ -215,7 +215,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti
 class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
     def __init__(self, config: UniSpeechConfig):
-        UniSpeechPreTrainedModel.__init__(config)
+        UniSpeechPreTrainedModel.__init__(self, config)
         self.config = config
         self.feature_extractor = UniSpeechFeatureEncoder(config)
         self.feature_projection = UniSpeechFeatureProjection(config)
diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
index ff1a2fefe4e5..3e1d99939215 100644
--- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
@@ -101,7 +101,7 @@ class UniSpeechSatEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
 class UniSpeechSatGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
     def __init__(self, config):
-        super().__init__()
+        super().__init__(config)
         self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars)
 
     @staticmethod
@@ -227,7 +227,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti
 class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model):
     def __init__(self, config: UniSpeechSatConfig):
-        UniSpeechSatPreTrainedModel.__init__(config)
+        UniSpeechSatPreTrainedModel.__init__(self, config)
         self.config = config
         self.feature_extractor = UniSpeechSatFeatureEncoder(config)
         self.feature_projection = UniSpeechSatFeatureProjection(config)
diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py
index c78e5bf55072..b9b60a6bd3ad 100644
--- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py
+++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py
@@ -64,7 +64,7 @@ def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Ten
     return mask
 
 
-class Wav2Vec2BertRotaryPositionalEmbedding(Wav2Vec2ConformerRotaryPositionalEmbedding, nn.Module):
+class Wav2Vec2BertRotaryPositionalEmbedding(Wav2Vec2ConformerRotaryPositionalEmbedding):
     def __init__(self, config):
         nn.Module.__init__(self)
         dim = config.hidden_size // config.num_attention_heads
@@ -96,7 +96,7 @@ def forward(self, hidden_states):
         return hidden_states, norm_hidden_states
 
 
-class Wav2Vec2BertFeedForward(Wav2Vec2FeedForward, nn.Module):
+class Wav2Vec2BertFeedForward(Wav2Vec2FeedForward):
     def __init__(self, config, act_fn=None, hidden_size=None):
         nn.Module.__init__(self)
         act_fn = act_fn if act_fn is not None else config.hidden_act
@@ -671,7 +671,7 @@ def _get_feature_vector_attention_mask(
 class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel):
     def __init__(self, config: Wav2Vec2BertConfig):
-        Wav2Vec2BertPreTrainedModel.__init__(config)
+        Wav2Vec2BertPreTrainedModel.__init__(self, config)
         self.config = config
         self.feature_projection = Wav2Vec2BertFeatureProjection(config)
diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py
index dc3a937e56ba..2c009c004453 100644
--- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py
@@ -647,7 +647,7 @@ def _get_feature_vector_attention_mask(
 class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel, Wav2Vec2Model):
     def __init__(self, config: Wav2Vec2ConformerConfig):
-        Wav2Vec2ConformerPreTrainedModel.__init__(config)
+        Wav2Vec2ConformerPreTrainedModel.__init__(self, config)
         self.config = config
         self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
         self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index 820baf06b5a9..14f189d2f1cc 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -1434,7 +1434,7 @@ def forward(
         offset = 0
         with torch.no_grad():
             if cache_params is None:
-                cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0])
+                cache_params = xLSTMCache(config=self.config, max_batch_size=hidden_states.shape[0])
             final_state = torch.zeros_like(hidden_states)
             while offset < hidden_states.shape[1]:
                 hidden_states_chunk = hidden_states[