
Commit 3ab1c99

Revert "use is_composition for pixtral"
This reverts commit a53d5f9.
1 parent a53d5f9 commit 3ab1c99

File tree

8 files changed (+126, -4 lines)


docs/source/en/model_doc/pixtral.md

Lines changed: 4 additions & 0 deletions
@@ -78,6 +78,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
 
 [[autodoc]] PixtralVisionConfig
 
+## PixtralTextConfig
+
+[[autodoc]] PixtralTextConfig
+
 ## PixtralVisionModel
 
 [[autodoc]] PixtralVisionModel

src/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -700,7 +700,7 @@
         "Pix2StructTextConfig",
         "Pix2StructVisionConfig",
     ],
-    "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
+    "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig", "PixtralTextConfig"],
     "models.plbart": ["PLBartConfig"],
     "models.poolformer": ["PoolFormerConfig"],
    "models.pop2piano": ["Pop2PianoConfig"],

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 0 deletions
@@ -232,6 +232,7 @@
         ("phimoe", "PhimoeConfig"),
         ("pix2struct", "Pix2StructConfig"),
         ("pixtral", "PixtralVisionConfig"),
+        ("pixtral_text", "PixtralTextConfig"),
         ("plbart", "PLBartConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
@@ -574,6 +575,7 @@
         ("phobert", "PhoBERT"),
         ("pix2struct", "Pix2Struct"),
         ("pixtral", "Pixtral"),
+        ("pixtral_text", "PixtralMistral"),
         ("plbart", "PLBart"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),
@@ -740,6 +742,7 @@
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("pixtral_text", "pixtral"),
     ]
 )
 
src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -555,6 +555,7 @@
         ("phi", "PhiForCausalLM"),
         ("phi3", "Phi3ForCausalLM"),
         ("phimoe", "PhimoeForCausalLM"),
+        ("pixtral_text", "MistralForCausalLM"),
         ("plbart", "PLBartForCausalLM"),
         ("prophetnet", "ProphetNetForCausalLM"),
         ("qdqbert", "QDQBertLMHeadModel"),

src/transformers/models/llava/configuration_llava.py

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ class LlavaConfig(PretrainedConfig):
 
     model_type = "llava"
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
-    is_composition = True
 
     def __init__(
         self,

src/transformers/models/pixtral/configuration_pixtral.py

Lines changed: 114 additions & 1 deletion
@@ -14,6 +14,7 @@
 """Pixtral model configuration"""
 
 from ...configuration_utils import PretrainedConfig
+from ...models.mistral.configuration_mistral import MistralConfig
 from ...utils import logging
 
 
@@ -103,4 +104,116 @@ def __init__(
         self.initializer_range = initializer_range
 
 
-__all__ = ["PixtralVisionConfig"]
+class PixtralTextConfig(MistralConfig):
+    r"""
+    TODO
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MistralModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
+            allows sequence of up to 4096*32 tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> TODO
+    ```"""
+
+    model_type = "pixtral_text"
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        sliding_window=4096,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.head_dim = head_dim  # as opposed to MistralConfig, do not auto-populate
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        PretrainedConfig.__init__(
+            self,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["PixtralVisionConfig", "PixtralTextConfig"]
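
A short usage sketch of the re-added class (an assumption based on this diff, for a build that includes this commit; the values are arbitrary), mainly to show the head_dim behaviour noted in the inline comment above:

from transformers import PixtralTextConfig

# PixtralTextConfig accepts the same arguments as MistralConfig, but unlike
# MistralConfig it does not auto-populate head_dim from hidden_size // num_attention_heads.
config = PixtralTextConfig(hidden_size=5120, num_attention_heads=32, head_dim=128)

print(config.model_type)  # "pixtral_text"
print(config.head_dim)    # 128, stored exactly as passed

default_config = PixtralTextConfig()
print(default_config.head_dim)  # None, not recomputed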

tests/models/llava/test_configuration_llava.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ def test_pixtral_reload(self):
         }
 
         text_config = {
-            "model_type": "mistral",
+            # "model_type": "mistral",
+            "model_type": "pixtral_text",
             "hidden_size": 5120,
             "head_dim": 128,
             "num_attention_heads": 32,

utils/check_table.py

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ def _center_text(text: str, width: int) -> str:
     "CLIPVisionModel",
     "Qwen2AudioEncoder",
     "SiglipVisionModel",
+    "PixtralMistral",  # not a real model
 ]
 
 
