|
19 | 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
20 | 20 | # See the License for the specific language governing permissions and |
21 | 21 | # limitations under the License. |
22 | | -from typing import Optional |
| 22 | +import fractions |
| 23 | +from collections.abc import Sequence |
| 24 | +from typing import Any, Optional, Union |
23 | 25 |
|
24 | 26 | from ...configuration_utils import PretrainedConfig |
25 | | -from ...modeling_rope_utils import rope_config_validation |
26 | 27 | from ...utils import logging |
27 | | -from ..siglip import SiglipVisionConfig |
| 28 | +from ..gemma3 import Gemma3TextConfig |
28 | 29 |
|
29 | 30 |
|
30 | 31 | logger = logging.get_logger(__name__) |
31 | 32 |
|
32 | 33 |
|
33 | | -class Gemma3p5TextConfig(PretrainedConfig): |
| 34 | +class Gemma3p5TextConfig(Gemma3TextConfig): |
34 | 35 | r""" |
35 | 36 | This is the configuration class to store the configuration of a [`Gemma3p5TextModel`]. It is used to instantiate a Gemma3p5Text
36 | 37 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the |
37 | | - defaults will yield a similar configuration to that of the Gemma3p5Text-7B. |
38 | | - e.g. [google/gemma3p5_text-7b](https://huggingface.co/google/gemma3p5_text-7b) |
39 | | - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the |
40 | | - documentation from [`PretrainedConfig`] for more information. |
| 38 | + defaults will yield a similar configuration to that of the Gemma3p5Text-4B. |
| 39 | + e.g. [google/gemma3p5_text-4b](https://huggingface.co/google/gemma3p5_text-4b) # TODO(sindhuraghuram): Update the link here
| 40 | + Configuration objects inherit from [`Gemma3TextConfig`] and can be used to control the model outputs. Read the |
| 41 | + documentation from [`Gemma3TextConfig`] for more information. |
41 | 42 | Args: |
42 | | - vocab_size (`int`, *optional*, defaults to 262208):
| 43 | + vocab_size (`int`, *optional*, defaults to 262144):
43 | 44 | Vocabulary size of the Gemma3p5Text model. Defines the number of different tokens that can be represented by the |
@@ -134,105 +135,100 @@ class Gemma3p5TextConfig(PretrainedConfig): |
134 | 135 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE |
135 | 136 | rope_local_base_freq (float, *optional*, defaults to 10000.0): |
136 | 137 | The base period of the RoPE embeddings for local attention. |
137 | | - sliding_window_pattern (`int`, *optional*, defaults to 6): |
| 138 | + sliding_window_pattern (`int`, *optional*, defaults to 5): |
138 | 139 | Pattern for the sliding window attention. |
139 | 140 |
|
| 141 | + TODO (sindhuraghuram): Update the list of configs |
| 142 | +
|
140 | 143 | ```python |
141 | 144 | >>> from transformers import Gemma3p5TextModel, Gemma3p5TextConfig |
142 | | - >>> # Initializing a Gemma3p5Text gemma3p5_text-7b style configuration |
| 145 | + >>> # Initializing a Gemma3p5Text gemma3p5_text-4b style configuration |
143 | 146 | >>> configuration = Gemma3p5TextConfig() |
144 | | - >>> # Initializing a model from the gemma3p5_text-7b style configuration |
| 147 | + >>> # Initializing a model from the gemma3p5_text-4b style configuration |
145 | 148 | >>> model = Gemma3p5TextModel(configuration) |
146 | 149 | >>> # Accessing the model configuration |
147 | 150 | >>> configuration = model.config |
148 | 151 | ``` |
149 | 152 | rope_local_base_freq (float, *optional*, defaults to 10000.0): |
150 | 153 | The base period of the RoPE embeddings for local attention. |
151 | | - sliding_window_pattern (`int`, *optional*, defaults to 6): |
| 154 | + sliding_window_pattern (`int`, *optional*, defaults to 5): |
152 | 155 | Pattern for the sliding window attention. |
153 | 156 | """ |
154 | 157 |
|
155 | 158 | model_type = "gemma3p5_text" |
156 | | - keys_to_ignore_at_inference = ["past_key_values"] |
157 | | - base_model_tp_plan = { |
158 | | - "layers.*.self_attn.q_proj": "colwise", |
159 | | - "layers.*.self_attn.k_proj": "colwise", |
160 | | - "layers.*.self_attn.v_proj": "colwise", |
161 | | - "layers.*.self_attn.o_proj": "rowwise", |
162 | | - "layers.*.mlp.gate_proj": "colwise", |
163 | | - "layers.*.mlp.up_proj": "colwise", |
164 | | - "layers.*.mlp.down_proj": "rowwise", |
165 | | - } |
166 | | - base_model_pp_plan = { |
167 | | - "embed_tokens": (["input_ids"], ["inputs_embeds"]), |
168 | | - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), |
169 | | - "norm": (["hidden_states"], ["hidden_states"]), |
170 | | - } |
171 | 159 |
|
172 | 160 | def __init__( |
173 | 161 | self, |
174 | | - vocab_size=262_208, |
175 | | - hidden_size=2304, |
176 | | - intermediate_size=9216, |
177 | | - num_hidden_layers=26, |
178 | | - num_attention_heads=8, |
179 | | - num_key_value_heads=4, |
180 | | - head_dim=256, |
181 | | - hidden_activation="gelu_pytorch_tanh", |
182 | | - max_position_embeddings=131_072, |
183 | | - initializer_range=0.02, |
184 | | - rms_norm_eps=1e-6, |
185 | | - use_cache=True, |
186 | | - pad_token_id=0, |
187 | | - eos_token_id=1, |
188 | | - bos_token_id=2, |
189 | | - tie_word_embeddings=True, |
190 | | - rope_theta=1_000_000.0, |
191 | | - attention_bias=False, |
192 | | - attention_dropout=0.0, |
193 | | - query_pre_attn_scalar=256, |
194 | | - sliding_window=4096, |
195 | | - final_logit_softcapping=None, |
196 | | - attn_logit_softcapping=None, |
197 | | - cache_implementation="hybrid", |
198 | | - rope_scaling=None, |
199 | | - rope_local_base_freq=10_000.0, |
200 | | - sliding_window_pattern=6, |
201 | | - **kwargs, |
| 162 | + vocab_size: int = 262_144, |
| 163 | + hidden_size: int = 2048, |
| 164 | + hidden_size_per_layer_input: int = 256, |
| 165 | + num_hidden_layers: int = 35, |
| 166 | + sliding_window: int = 512, |
| 167 | + intermediate_size: int = 16_384, |
| 168 | + num_key_value_heads: int = 2, |
| 169 | + rope_theta: float = 1_000_000.0, |
| 170 | + rope_local_base_freq: float = 10_000.0, |
| 171 | + sliding_window_pattern: int = 5, |
| 172 | + final_logit_softcapping: float = 30.0, |
| 173 | + altup_active_idx: int = 0, |
| 174 | + altup_coef_clip: float = 120.0, |
| 175 | + altup_lr_multiplier: float = 1.0, |
| 176 | + altup_num_inputs: int = 8, |
| 177 | + altup_num_modalities: int = 4, |
| 178 | + frac_shared_layers: Union[float, fractions.Fraction] = 0.5, |
| 179 | + laurel_rank: int = 64, |
| 180 | + activation_sparsity_pattern: Optional[Sequence[float]] = None, |
| 181 | + **super_kwargs, |
202 | 182 | ): |
| 183 | + super_kwargs["rope_scaling"] = None # Gemma3p5 does not use RoPE scaling; force it off before delegating to Gemma3TextConfig
| 184 | + |
203 | 185 | super().__init__( |
204 | | - pad_token_id=pad_token_id, |
205 | | - bos_token_id=bos_token_id, |
206 | | - eos_token_id=eos_token_id, |
207 | | - tie_word_embeddings=tie_word_embeddings, |
208 | | - **kwargs, |
| 186 | + vocab_size=vocab_size, |
| 187 | + hidden_size=hidden_size, |
| 188 | + num_hidden_layers=num_hidden_layers, |
| 189 | + num_key_value_heads=num_key_value_heads, |
| 190 | + intermediate_size=intermediate_size, |
| 191 | + rope_theta=rope_theta, |
| 192 | + rope_local_base_freq=rope_local_base_freq, |
| 193 | + sliding_window=sliding_window, |
| 194 | + sliding_window_pattern=sliding_window_pattern, |
| 195 | + final_logit_softcapping=final_logit_softcapping, |
| 196 | + **super_kwargs, |
209 | 197 | ) |
210 | | - self.vocab_size = vocab_size |
211 | | - self.max_position_embeddings = max_position_embeddings |
212 | | - self.hidden_size = hidden_size |
213 | | - self.intermediate_size = intermediate_size |
214 | | - self.num_hidden_layers = num_hidden_layers |
215 | | - self.num_attention_heads = num_attention_heads |
216 | | - self.head_dim = head_dim |
217 | | - self.num_key_value_heads = num_key_value_heads |
218 | | - self.initializer_range = initializer_range |
219 | | - self.rms_norm_eps = rms_norm_eps |
220 | | - self.use_cache = use_cache |
221 | | - self.rope_theta = rope_theta |
222 | | - self.attention_bias = attention_bias |
223 | | - self.attention_dropout = attention_dropout |
224 | | - self.hidden_activation = hidden_activation |
225 | | - self.query_pre_attn_scalar = query_pre_attn_scalar |
226 | | - self.sliding_window = sliding_window |
227 | | - self.final_logit_softcapping = final_logit_softcapping |
228 | | - self.attn_logit_softcapping = attn_logit_softcapping |
229 | | - self.cache_implementation = cache_implementation |
230 | | - |
231 | | - self.rope_local_base_freq = rope_local_base_freq |
232 | | - # For configuring HybridCache to work with 5:1 attention pattern |
233 | | - self.sliding_window_pattern = sliding_window_pattern |
234 | | - self.rope_scaling = rope_scaling |
235 | | - rope_config_validation(self) |
| 198 | + self.hidden_size_per_layer_input = hidden_size_per_layer_input |
| 199 | + |
| 200 | + self.altup_active_idx = altup_active_idx |
| 201 | + self.altup_coef_clip = altup_coef_clip |
| 202 | + self.altup_lr_multiplier = altup_lr_multiplier |
| 203 | + self.altup_num_inputs = altup_num_inputs |
| 204 | + self.altup_num_modalities = altup_num_modalities |
| 205 | + |
| 206 | + self.laurel_rank = laurel_rank |
| 207 | + |
| 208 | + self.frac_shared_layers = frac_shared_layers |
| 209 | + if ( |
| 210 | + activation_sparsity_pattern is not None |
| 211 | + and (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers |
| 212 | + ): |
| 213 | + raise ValueError( |
| 214 | + "activation_sparsity_pattern must have an explicit activation sparsity value for every layer." |
| 215 | + f"Expected {num_hidden_layers} values but got {len_asp}." |
| 216 | + ) |
| 217 | + self.activation_sparsity_pattern = activation_sparsity_pattern |
| 218 | + |
| 219 | + |
| 220 | +class Gemma3p5AudioConfig(PretrainedConfig): |
| 221 | + model_type = "gemma3p5" |
| 222 | + |
| 223 | + def __init__(self, **kwargs):
| 224 | + super().__init__(**kwargs)
| 225 | + |
| 226 | + |
| 227 | +class Gemma3p5VisionConfig(PretrainedConfig): |
| 228 | + model_type = "gemma3p5" |
| 229 | + |
| 230 | + def __init__(self): |
| 231 | + pass |
236 | 232 |
|
237 | 233 |
|
238 | 234 | class Gemma3p5Config(PretrainedConfig): |
@@ -287,37 +283,43 @@ class Gemma3p5Config(PretrainedConfig): |
287 | 283 | model_type = "gemma3p5" |
288 | 284 | sub_configs = { |
289 | 285 | "text_config": Gemma3p5TextConfig, |
290 | | - "vision_config": SiglipVisionConfig, |
| 286 | + "vision_config": Gemma3p5VisionConfig, |
| 287 | + "audio_config": Gemma3p5AudioConfig, |
291 | 288 | } |
292 | 289 |
|
293 | 290 | def __init__( |
294 | 291 | self, |
295 | | - text_config: Optional[Gemma3p5TextConfig] = None, |
296 | | - vision_config: Optional[SiglipVisionConfig] = None, |
| 292 | + text_config: Optional[Union[Gemma3p5TextConfig, dict[str, Any]]] = None, |
| 293 | + vision_config: Optional[Union[Gemma3p5VisionConfig, dict[str, Any]]] = None, |
| 294 | + audio_config: Optional[Union[Gemma3p5AudioConfig, dict[str, Any]]] = None, |
297 | 295 | mm_tokens_per_image: int = 256, |
298 | 296 | boi_token_index: int = 255_999, |
299 | 297 | eoi_token_index: int = 256_000, |
300 | 298 | image_token_index: int = 262_144, |
301 | 299 | initializer_range: float = 0.02, |
302 | 300 | **kwargs, |
303 | 301 | ): |
304 | | - if text_config is None: |
305 | | - text_config = Gemma3p5TextConfig() |
306 | | - logger.info("text_config is None, using default Gemma3p5TextConfig vision config.") |
307 | | - elif isinstance(text_config, dict): |
| 302 | + if isinstance(text_config, dict): |
308 | 303 | text_config = Gemma3p5TextConfig(**text_config) |
| 304 | + elif text_config is None: |
| 305 | + text_config = Gemma3p5TextConfig() |
| 306 | + logger.info("text_config is None. Using default Gemma3p5TextConfig.") |
309 | 307 |
|
310 | 308 | if isinstance(vision_config, dict): |
311 | | - vision_config = SiglipVisionConfig(**vision_config) |
312 | | - else: |
313 | | - vision_config = SiglipVisionConfig() |
314 | | - logger.info( |
315 | | - "vision_config is None or incompatible with Gemma3p5VisionConfig intialization. Gemma3p5 will be limited " |
316 | | - "to text tasks." |
317 | | - ) |
| 309 | + vision_config = Gemma3p5VisionConfig(**vision_config) |
| 310 | + elif vision_config is None: |
| 311 | + vision_config = Gemma3p5VisionConfig() |
| 312 | + logger.info("vision_config is None. Using default Gemma3p5VisionConfig.") |
| 313 | + |
| 314 | + if isinstance(audio_config, dict): |
| 315 | + audio_config = Gemma3p5AudioConfig(**audio_config) |
| 316 | + elif audio_config is None: |
| 317 | + audio_config = Gemma3p5AudioConfig() |
| 318 | + logger.info("audio_config is None. Using default Gemma3p5AudioConfig.") |
318 | 319 |
|
319 | 320 | self.text_config = text_config |
320 | 321 | self.vision_config = vision_config |
| 322 | + self.audio_config = audio_config |
321 | 323 | self.mm_tokens_per_image = mm_tokens_per_image |
322 | 324 | self.boi_token_index = boi_token_index |
323 | 325 | self.eoi_token_index = eoi_token_index |
|
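For reference, a minimal usage sketch of the composition above, assuming the classes from this diff are in scope; the layer count and sparsity values are illustrative, not release defaults:

```python
# Sub-configs may be passed as plain dicts; Gemma3p5Config coerces them to
# their config classes, and builds defaults for any that are omitted.
config = Gemma3p5Config(
    text_config={"num_hidden_layers": 35, "activation_sparsity_pattern": [0.95] * 35},
)
assert isinstance(config.text_config, Gemma3p5TextConfig)
assert isinstance(config.vision_config, Gemma3p5VisionConfig)  # default-built

# Gemma3p5TextConfig requires one sparsity value per layer when a pattern is given.
try:
    Gemma3p5TextConfig(num_hidden_layers=4, activation_sparsity_pattern=[0.5, 0.5])
except ValueError as err:
    print(err)  # "... Expected 4 values but got 2."
```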