Skip to content

Commit f22cb1e

Browse files
authored
fix qwen text config (#41158)
* fix qwen text config
* fix tests
* fix one more test
* address comments
1 parent 374ded5 commit f22cb1e

File tree

8 files changed

+207
-120
lines changed

8 files changed

+207
-120
lines changed

src/transformers/models/glm4v/configuration_glm4v.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,6 @@ def __init__(
330330
video_end_token_id=151342,
331331
**kwargs,
332332
):
333-
super().__init__(**kwargs)
334333
if isinstance(vision_config, dict):
335334
self.vision_config = self.sub_configs["vision_config"](**vision_config)
336335
elif vision_config is None:
@@ -339,7 +338,6 @@ def __init__(
339338
if isinstance(text_config, dict):
340339
self.text_config = self.sub_configs["text_config"](**text_config)
341340
elif text_config is None:
342-
# For BC use all kwargs to init `TextConfig`
343341
self.text_config = self.sub_configs["text_config"](**kwargs)
344342

345343
self.image_token_id = image_token_id
@@ -349,5 +347,7 @@ def __init__(
349347
self.image_start_token_id = image_start_token_id
350348
self.image_end_token_id = image_end_token_id
351349

350+
super().__init__(**kwargs)
351+
352352

353353
__all__ = ["Glm4vConfig", "Glm4vTextConfig"]

src/transformers/models/glm4v/modular_glm4v.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from ...utils.generic import check_model_inputs
3939
from ...video_utils import VideoInput
4040
from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, eager_attention_forward
41-
from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
4241
from ..qwen2_5_vl.modeling_qwen2_5_vl import (
4342
Qwen2_5_VisionPatchEmbed,
4443
Qwen2_5_VisionRotaryEmbedding,
@@ -313,7 +312,7 @@ def __init__(
313312
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
314313

315314

316-
class Glm4vConfig(Qwen2_5_VLConfig):
315+
class Glm4vConfig(PretrainedConfig):
317316
r"""
318317
This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a
319318
GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a
@@ -355,6 +354,10 @@ class Glm4vConfig(Qwen2_5_VLConfig):
355354
>>> configuration = model.config
356355
```"""
357356

357+
model_type = "glm4v"
358+
sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig}
359+
keys_to_ignore_at_inference = ["past_key_values"]
360+
358361
def __init__(
359362
self,
360363
text_config=None,
@@ -367,12 +370,25 @@ def __init__(
367370
video_end_token_id=151342,
368371
**kwargs,
369372
):
370-
super().__init__()
373+
if isinstance(vision_config, dict):
374+
self.vision_config = self.sub_configs["vision_config"](**vision_config)
375+
elif vision_config is None:
376+
self.vision_config = self.sub_configs["vision_config"]()
377+
378+
if isinstance(text_config, dict):
379+
self.text_config = self.sub_configs["text_config"](**text_config)
380+
elif text_config is None:
381+
self.text_config = self.sub_configs["text_config"](**kwargs)
382+
383+
self.image_token_id = image_token_id
384+
self.video_token_id = video_token_id
371385
self.video_start_token_id = video_start_token_id
372386
self.video_end_token_id = video_end_token_id
373387
self.image_start_token_id = image_start_token_id
374388
self.image_end_token_id = image_end_token_id
375389

390+
super().__init__(**kwargs)
391+
376392

377393
# Will be used for both Text and Vision modalities
378394
class Glm4vRMSNorm(Glm4RMSNorm):

src/transformers/models/glm4v_moe/configuration_glm4v_moe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,6 @@ def __init__(
371371
if isinstance(text_config, dict):
372372
self.text_config = self.sub_configs["text_config"](**text_config)
373373
elif text_config is None:
374-
# For BC use all kwargs to init `TextConfig`
375374
self.text_config = self.sub_configs["text_config"](**kwargs)
376375

377376
self.image_token_id = image_token_id

src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,6 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
159159
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
160160
`high_freq_factor` (`float`, *optional*):
161161
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
162-
image_token_id (`int`, *optional*):
163-
Token index used as placeholder for image embeddings.
164-
video_token_id (`int`, *optional*):
165-
Token index used as placeholder for video embeddings.
166162
167163
```python
168164
>>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
@@ -217,8 +213,6 @@ def __init__(
217213
layer_types=None,
218214
attention_dropout=0.0,
219215
rope_scaling=None,
220-
image_token_id=None,
221-
video_token_id=None,
222216
**kwargs,
223217
):
224218
self.vocab_size = vocab_size
@@ -264,9 +258,6 @@ def __init__(
264258
self.rope_scaling["type"] = "default"
265259
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
266260
rope_config_validation(self, ignore_keys={"mrope_section"})
267-
self.image_token_id = image_token_id
268-
self.video_token_id = video_token_id
269-
270261
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
271262

272263

@@ -290,6 +281,10 @@ class Qwen2_5_VLConfig(PretrainedConfig):
290281
The image token index to encode the image prompt.
291282
video_token_id (`int`, *optional*, defaults to 151656):
292283
The video token index to encode the video prompt.
284+
vision_start_token_id (`int`, *optional*, defaults to 151652):
285+
The token index to denote start of vision input.
286+
vision_end_token_id (`int`, *optional*, defaults to 151653):
287+
The token index to denote end of vision input.
293288
294289
```python
295290
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
@@ -314,8 +309,15 @@ def __init__(
314309
vision_config=None,
315310
image_token_id=151655,
316311
video_token_id=151656,
312+
vision_start_token_id=151652,
313+
vision_end_token_id=151653,
317314
**kwargs,
318315
):
316+
# We need to init super() here so that it does not reset values
317+
# that are in text config to the BaseClass defaults. The Base
318+
# config has many text related defaults and not all defaults are same as for `Qwen2_5_VLTextConfig`
319+
super().__init__(**kwargs)
320+
319321
if isinstance(vision_config, dict):
320322
self.vision_config = self.sub_configs["vision_config"](**vision_config)
321323
elif vision_config is None:
@@ -329,8 +331,32 @@ def __init__(
329331

330332
self.image_token_id = image_token_id
331333
self.video_token_id = video_token_id
332-
333-
super().__init__(**kwargs)
334+
self.vision_start_token_id = vision_start_token_id
335+
self.vision_end_token_id = vision_end_token_id
336+
337+
# Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end
338+
self._attn_implementation = kwargs.pop("attn_implementation", None)
339+
340+
def __setattr__(self, key, value):
341+
if (
342+
(text_config := super().__getattribute__("__dict__").get("text_config")) is not None
343+
and key not in ["dtype", "_attn_implementation_internal"]
344+
and key in text_config.__dict__
345+
):
346+
setattr(text_config, key, value)
347+
else:
348+
super().__setattr__(key, value)
349+
350+
def __getattribute__(self, key):
351+
if "text_config" in super().__getattribute__("__dict__") and key not in [
352+
"dtype",
353+
"_attn_implementation_internal",
354+
]:
355+
text_config = super().__getattribute__("text_config")
356+
if key in text_config.__dict__:
357+
return getattr(text_config, key)
358+
359+
return super().__getattribute__(key)
334360

335361

336362
__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]

src/transformers/models/qwen2_vl/configuration_qwen2_vl.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,6 @@ class Qwen2VLTextConfig(PretrainedConfig):
148148
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
149149
`high_freq_factor` (`float`, *optional*):
150150
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
151-
image_token_id (`int`, *optional*):
152-
Token index used as placeholder for image embeddings.
153-
video_token_id (`int`, *optional*):
154-
Token index used as placeholder for video embeddings.
155151
156152
```python
157153
>>> from transformers import Qwen2VLTextModel, Qwen2VLConfig
@@ -206,8 +202,6 @@ def __init__(
206202
layer_types=None,
207203
attention_dropout=0.0,
208204
rope_scaling=None,
209-
image_token_id=None,
210-
video_token_id=None,
211205
**kwargs,
212206
):
213207
self.vocab_size = vocab_size
@@ -253,9 +247,6 @@ def __init__(
253247
self.rope_scaling["type"] = "default"
254248
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
255249
rope_config_validation(self, ignore_keys={"mrope_section"})
256-
self.image_token_id = image_token_id
257-
self.video_token_id = video_token_id
258-
259250
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
260251

261252

@@ -271,23 +262,27 @@ class Qwen2VLConfig(PretrainedConfig):
271262
272263
273264
Args:
274-
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
265+
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2VLTextConfig`):
275266
The config object or dictionary of the text backbone.
276-
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
267+
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2VLVisionConfig`):
277268
The config object or dictionary of the vision backbone.
278269
image_token_id (`int`, *optional*, defaults to 151655):
279270
The image token index to encode the image prompt.
280271
video_token_id (`int`, *optional*, defaults to 151656):
281272
The video token index to encode the video prompt.
273+
vision_start_token_id (`int`, *optional*, defaults to 151652):
274+
The token index to denote start of vision input.
275+
vision_end_token_id (`int`, *optional*, defaults to 151653):
276+
The token index to denote end of vision input.
282277
283278
```python
284-
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
279+
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
285280
286-
>>> # Initializing a Qwen2_5_VL style configuration
287-
>>> configuration = Qwen2_5_VLConfig()
281+
>>> # Initializing a Qwen2VL style configuration
282+
>>> configuration = Qwen2VLConfig()
288283
289284
>>> # Initializing a model from the Qwen2-VL-7B style configuration
290-
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
285+
>>> model = Qwen2VLForConditionalGeneration(configuration)
291286
292287
>>> # Accessing the model configuration
293288
>>> configuration = model.config
@@ -303,8 +298,15 @@ def __init__(
303298
vision_config=None,
304299
image_token_id=151655,
305300
video_token_id=151656,
301+
vision_start_token_id=151652,
302+
vision_end_token_id=151653,
306303
**kwargs,
307304
):
305+
# We need to init super() here so that it does not reset values
306+
# that are in text config to the BaseClass defaults. The Base
307+
# config has many text related defaults and not all defaults are same as for `Qwen2VLTextConfig`
308+
super().__init__(**kwargs)
309+
308310
if isinstance(vision_config, dict):
309311
self.vision_config = self.sub_configs["vision_config"](**vision_config)
310312
elif vision_config is None:
@@ -318,8 +320,32 @@ def __init__(
318320

319321
self.image_token_id = image_token_id
320322
self.video_token_id = video_token_id
321-
322-
super().__init__(**kwargs)
323+
self.vision_start_token_id = vision_start_token_id
324+
self.vision_end_token_id = vision_end_token_id
325+
326+
# Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end
327+
self._attn_implementation = kwargs.pop("attn_implementation", None)
328+
329+
def __setattr__(self, key, value):
330+
if (
331+
(text_config := super().__getattribute__("__dict__").get("text_config")) is not None
332+
and key not in ["dtype", "_attn_implementation_internal"]
333+
and key in text_config.__dict__
334+
):
335+
setattr(text_config, key, value)
336+
else:
337+
super().__setattr__(key, value)
338+
339+
def __getattribute__(self, key):
340+
if "text_config" in super().__getattribute__("__dict__") and key not in [
341+
"dtype",
342+
"_attn_implementation_internal",
343+
]:
344+
text_config = super().__getattribute__("text_config")
345+
if key in text_config.__dict__:
346+
return getattr(text_config, key)
347+
348+
return super().__getattribute__(key)
323349

324350

325351
__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]

0 commit comments

Comments
 (0)