
Commit 719411d

Initial Gemma3p5TextModel (#4)
NOTE: This implementation WILL CHANGE in the coming weeks; however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference.

* Adding KV Cache Sharing
* Adds Einsum layer to Gemma 3.5
* Updating EinsumLayer API
* Refactored kv cache sharing in attention
* Adding KVStore for cache sharing
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py (Co-authored-by: Ryan Mullins <ryanmullins@google.com>)
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py (Co-authored-by: Ryan Mullins <ryanmullins@google.com>)
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py (Co-authored-by: Ryan Mullins <ryanmullins@google.com>)
* Update src/transformers/cache_utils.py (Co-authored-by: Ryan Mullins <ryanmullins@google.com>)
* Undoing erroneous force push
* Reverting RMSNorm to with_scale by default
* Adds LAuReL to Gemma 3.5
* Updating KV Cache Sharing implementation
* Updating the q and k norm definitions in the attention module
* Fixing name error for q,k,v RMS norm to use the right 3p5 module
* Updating MLP with activation sparsity
* Updating DecoderBlock for Gemma 3.5
* Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code
* Isolating KV Cache logic to relevant components
* Fixing logic error in Gemma3p5Attention.forward
* Refactoring caching contributions and fixing kv_store initialization
* Simplifying Configs
* Remove errant self from super init call
* Bug fix in the Attention module - changing self.head_dim to config.head_dim
* Bug fixes in the LaurelBlock and RMS Norm super init call
* Removing redundant code from a merge
* Adding per_layer_inputs to TextModel
* Adding preprocess embeddings with altup
* Adds per-layer-to-single output and a host of TODOs
* Integrating altup predict with the model workflow and other minor bug fixes
* Using nn.Embedding temporarily for text model
* It goes forward
* Minor refactor of attention sparsity and RoPE initialization
* Fixing duplicate rope_scaling param bug when loading from pretrained

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
1 parent f76c5f9 commit 719411d

4 files changed: +1142 -850 lines changed


src/transformers/cache_utils.py

Lines changed: 65 additions & 0 deletions
@@ -2445,3 +2445,68 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None:
 
         self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
         self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
+
+
+@dataclass
+class KVStore:
+    """KV cache container for a single layer.
+
+    For holding the k, v values during training for sharing them across depth.
+    Unlike AttenionKVCache, it does not split the cache into prefill and
+    generation segments or implement a cursor.
+    This is not stacked in a layer scan.
+
+    Shapes:
+        k: [batch_size, seq_len, num_kv_heads, head_dim]
+        v: [batch_size, seq_len, num_kv_heads, head_dim]
+    """
+
+    def __init__(self, x: torch.Tensor, num_kv_heads: int, head_dim: int) -> None:
+        """
+        Shape of x: [batch_size, seq_len, head_dim]
+
+        Returns:
+            A KVStore with zero-initialized 'k' and 'v', each with shape [batch_size, seq_len, num_kv_heads, head_dim].
+        """
+        b, t, _ = x.shape
+        self.k = torch.zeros((b, t, num_kv_heads, head_dim), dtype=x.dtype, device=x.device)
+        self.v = torch.zeros((b, t, num_kv_heads, head_dim), dtype=x.dtype, device=x.device)
+
+
+@dataclass
+class BlockKVStore:
+    """
+    Stores 2 KVStore objects:
+        - kv_local: for local/sliding window attention
+        - kv_global: for global attention
+    """
+
+    def __init__(
+        self,
+        kv_local: Optional[KVStore] = None,
+        kv_global: Optional[KVStore] = None,
+        x: Optional[torch.Tensor] = None,
+        num_kv_heads: Optional[int] = None,
+        num_global_kv_heads: Optional[int] = None,
+        head_dim: Optional[int] = None,
+    ):
+        if kv_local is None:
+            self.kv_local = KVStore(x, num_kv_heads, head_dim)
+        else:
+            self.kv_local = kv_local
+
+        if kv_global is None:
+            self.kv_global = KVStore(x, num_global_kv_heads, head_dim)
+        else:
+            self.kv_global = kv_global
+
+    def update_kv_store(self, is_sliding: bool, new_kv_store: KVStore) -> "BlockKVStore":
+        """Return a new BlockKVStore with either the local or global KVStore replaced."""
+        if is_sliding:
+            return BlockKVStore(kv_local=new_kv_store, kv_global=self.kv_global)
+        else:
+            return BlockKVStore(kv_local=self.kv_local, kv_global=new_kv_store)
+
+    def get_kv_store(self, is_sliding: bool) -> Optional[KVStore]:
+        """Return the relevant KVStore if sharing is enabled. Otherwise return None."""
+        return self.kv_local if is_sliding else self.kv_global

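The classes above are plain containers, so the intended flow is easy to sketch. The snippet below is a minimal usage sketch and is not part of the commit: it assumes a transformers build that includes this change, and the batch size, sequence length, and head counts are illustrative values only.

```python
import torch

# Assumes a transformers build containing this commit; both classes live in cache_utils.
from transformers.cache_utils import BlockKVStore, KVStore

# Illustrative shapes: batch=2, seq_len=16, hidden=32, 2 local KV heads, 1 global KV head, head_dim=8.
x = torch.zeros(2, 16, 32)
store = BlockKVStore(x=x, num_kv_heads=2, num_global_kv_heads=1, head_dim=8)

# A sliding-window layer fills a fresh KVStore with its projected keys and values ...
local_kv = KVStore(x, num_kv_heads=2, head_dim=8)
local_kv.k.normal_()  # stand-in for the real key projection
local_kv.v.normal_()  # stand-in for the real value projection

# ... and publishes it. update_kv_store returns a new container instead of mutating in place,
# so a deeper layer can read the shared cache back through get_kv_store.
store = store.update_kv_store(is_sliding=True, new_kv_store=local_kv)
shared = store.get_kv_store(is_sliding=True)
print(shared.k.shape)  # torch.Size([2, 16, 2, 8])
```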
src/transformers/models/gemma3p5/configuration_gemma3p5.py

Lines changed: 102 additions & 100 deletions
@@ -19,25 +19,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+import fractions
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
 from ...utils import logging
-from ..siglip import SiglipVisionConfig
+from ..gemma3 import Gemma3TextConfig
 
 
 logger = logging.get_logger(__name__)
 
 
-class Gemma3p5TextConfig(PretrainedConfig):
+class Gemma3p5TextConfig(Gemma3TextConfig):
     r"""
     This is the configuration class to store the configuration of a [`Gemma3p5TextModel`]. It is used to instantiate an Gemma3p5Text
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-    defaults will yield a similar configuration to that of the Gemma3p5Text-7B.
-    e.g. [google/gemma3p5_text-7b](https://huggingface.co/google/gemma3p5_text-7b)
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
+    defaults will yield a similar configuration to that of the Gemma3p5Text-4B.
+    e.g. [google/gemma3p5_text-4b](https://huggingface.co/google/gemma3p5_text-4b)  # TODO (sindhuraghuram): Update the link here
+    Configuration objects inherit from [`Gemma3TextConfig`] and can be used to control the model outputs. Read the
+    documentation from [`Gemma3TextConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 262208):
             Vocabulary size of the Gemma3p5Text model. Defines the number of different tokens that can be represented by the
@@ -134,105 +135,100 @@ class Gemma3p5TextConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
-        sliding_window_pattern (`int`, *optional*, defaults to 6):
+        sliding_window_pattern (`int`, *optional*, defaults to 5):
             Pattern for the sliding window attention.
 
+    TODO (sindhuraghuram): Update the list of configs
+
     ```python
     >>> from transformers import Gemma3p5TextModel, Gemma3p5TextConfig
-    >>> # Initializing a Gemma3p5Text gemma3p5_text-7b style configuration
+    >>> # Initializing a Gemma3p5Text gemma3p5_text-4b style configuration
     >>> configuration = Gemma3p5TextConfig()
-    >>> # Initializing a model from the gemma3p5_text-7b style configuration
+    >>> # Initializing a model from the gemma3p5_text-4b style configuration
     >>> model = Gemma3p5TextModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
-        sliding_window_pattern (`int`, *optional*, defaults to 6):
+        sliding_window_pattern (`int`, *optional*, defaults to 5):
             Pattern for the sliding window attention.
     """
 
     model_type = "gemma3p5_text"
-    keys_to_ignore_at_inference = ["past_key_values"]
-    base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise",
-        "layers.*.self_attn.k_proj": "colwise",
-        "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
-    }
-    base_model_pp_plan = {
-        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
-        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
-        "norm": (["hidden_states"], ["hidden_states"]),
-    }
 
     def __init__(
         self,
-        vocab_size=262_208,
-        hidden_size=2304,
-        intermediate_size=9216,
-        num_hidden_layers=26,
-        num_attention_heads=8,
-        num_key_value_heads=4,
-        head_dim=256,
-        hidden_activation="gelu_pytorch_tanh",
-        max_position_embeddings=131_072,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        eos_token_id=1,
-        bos_token_id=2,
-        tie_word_embeddings=True,
-        rope_theta=1_000_000.0,
-        attention_bias=False,
-        attention_dropout=0.0,
-        query_pre_attn_scalar=256,
-        sliding_window=4096,
-        final_logit_softcapping=None,
-        attn_logit_softcapping=None,
-        cache_implementation="hybrid",
-        rope_scaling=None,
-        rope_local_base_freq=10_000.0,
-        sliding_window_pattern=6,
-        **kwargs,
+        vocab_size: int = 262_144,
+        hidden_size: int = 2048,
+        hidden_size_per_layer_input: int = 256,
+        num_hidden_layers: int = 35,
+        sliding_window: int = 512,
+        intermediate_size: int = 16_384,
+        num_key_value_heads: int = 2,
+        rope_theta: float = 1_000_000.0,
+        rope_local_base_freq: float = 10_000.0,
+        sliding_window_pattern: int = 5,
+        final_logit_softcapping: float = 30.0,
+        altup_active_idx: int = 0,
+        altup_coef_clip: float = 120.0,
+        altup_lr_multiplier: float = 1.0,
+        altup_num_inputs: int = 8,
+        altup_num_modalities: int = 4,
+        frac_shared_layers: Union[float, fractions.Fraction] = 0.5,
+        laurel_rank: int = 64,
+        activation_sparsity_pattern: Optional[Sequence[float]] = None,
+        **super_kwargs,
     ):
+        super_kwargs["rope_scaling"] = None
+
         super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            num_hidden_layers=num_hidden_layers,
+            num_key_value_heads=num_key_value_heads,
+            intermediate_size=intermediate_size,
+            rope_theta=rope_theta,
+            rope_local_base_freq=rope_local_base_freq,
+            sliding_window=sliding_window,
+            sliding_window_pattern=sliding_window_pattern,
+            final_logit_softcapping=final_logit_softcapping,
+            **super_kwargs,
         )
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.head_dim = head_dim
-        self.num_key_value_heads = num_key_value_heads
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.hidden_activation = hidden_activation
-        self.query_pre_attn_scalar = query_pre_attn_scalar
-        self.sliding_window = sliding_window
-        self.final_logit_softcapping = final_logit_softcapping
-        self.attn_logit_softcapping = attn_logit_softcapping
-        self.cache_implementation = cache_implementation
-
-        self.rope_local_base_freq = rope_local_base_freq
-        # For configuring HybridCache to work with 5:1 attention pattern
-        self.sliding_window_pattern = sliding_window_pattern
-        self.rope_scaling = rope_scaling
-        rope_config_validation(self)
+        self.hidden_size_per_layer_input = hidden_size_per_layer_input
+
+        self.altup_active_idx = altup_active_idx
+        self.altup_coef_clip = altup_coef_clip
+        self.altup_lr_multiplier = altup_lr_multiplier
+        self.altup_num_inputs = altup_num_inputs
+        self.altup_num_modalities = altup_num_modalities
+
+        self.laurel_rank = laurel_rank
+
+        self.frac_shared_layers = frac_shared_layers
+        if (
+            activation_sparsity_pattern is not None
+            and (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers
+        ):
+            raise ValueError(
+                "activation_sparsity_pattern must have an explicit activation sparsity value for every layer. "
+                f"Expected {num_hidden_layers} values but got {len_asp}."
+            )
+        self.activation_sparsity_pattern = activation_sparsity_pattern
+
+
+class Gemma3p5AudioConfig(PretrainedConfig):
+    model_type = "gemma3p5"
+
+    def __init__(self):
+        pass
+
+
+class Gemma3p5VisionConfig(PretrainedConfig):
+    model_type = "gemma3p5"
+
+    def __init__(self):
+        pass
 
 
 class Gemma3p5Config(PretrainedConfig):
@@ -287,37 +283,43 @@ class Gemma3p5Config(PretrainedConfig):
     model_type = "gemma3p5"
     sub_configs = {
         "text_config": Gemma3p5TextConfig,
-        "vision_config": SiglipVisionConfig,
+        "vision_config": Gemma3p5VisionConfig,
+        "audio_config": Gemma3p5AudioConfig,
     }
 
     def __init__(
         self,
-        text_config: Optional[Gemma3p5TextConfig] = None,
-        vision_config: Optional[SiglipVisionConfig] = None,
+        text_config: Optional[Union[Gemma3p5TextConfig, dict[str, Any]]] = None,
+        vision_config: Optional[Union[Gemma3p5VisionConfig, dict[str, Any]]] = None,
+        audio_config: Optional[Union[Gemma3p5AudioConfig, dict[str, Any]]] = None,
         mm_tokens_per_image: int = 256,
         boi_token_index: int = 255_999,
         eoi_token_index: int = 256_000,
         image_token_index: int = 262_144,
         initializer_range: float = 0.02,
         **kwargs,
     ):
-        if text_config is None:
-            text_config = Gemma3p5TextConfig()
-            logger.info("text_config is None, using default Gemma3p5TextConfig vision config.")
-        elif isinstance(text_config, dict):
+        if isinstance(text_config, dict):
             text_config = Gemma3p5TextConfig(**text_config)
+        elif text_config is None:
+            text_config = Gemma3p5TextConfig()
+            logger.info("text_config is None. Using default Gemma3p5TextConfig.")
 
         if isinstance(vision_config, dict):
-            vision_config = SiglipVisionConfig(**vision_config)
-        else:
-            vision_config = SiglipVisionConfig()
-            logger.info(
-                "vision_config is None or incompatible with Gemma3p5VisionConfig intialization. Gemma3p5 will be limited "
-                "to text tasks."
-            )
+            vision_config = Gemma3p5VisionConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = Gemma3p5VisionConfig()
+            logger.info("vision_config is None. Using default Gemma3p5VisionConfig.")
+
+        if isinstance(audio_config, dict):
+            audio_config = Gemma3p5AudioConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = Gemma3p5AudioConfig()
+            logger.info("audio_config is None. Using default Gemma3p5AudioConfig.")
 
         self.text_config = text_config
         self.vision_config = vision_config
+        self.audio_config = audio_config
         self.mm_tokens_per_image = mm_tokens_per_image
         self.boi_token_index = boi_token_index
         self.eoi_token_index = eoi_token_index

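As a quick sanity check of the new configuration surface, the sketch below is not part of the commit: it assumes a transformers build that includes this change (the top-level import mirrors the docstring example above), and the sparsity values are illustrative only.

```python
from transformers import Gemma3p5TextConfig

# Defaults introduced here: 35 hidden layers, hidden_size 2048, sliding_window 512, sliding_window_pattern 5.
config = Gemma3p5TextConfig()
print(config.num_hidden_layers, config.altup_num_inputs, config.frac_shared_layers)  # 35 8 0.5

# activation_sparsity_pattern must provide exactly one value per hidden layer ...
config = Gemma3p5TextConfig(activation_sparsity_pattern=[0.95] * 10 + [0.0] * 25)

# ... otherwise __init__ raises a ValueError.
try:
    Gemma3p5TextConfig(num_hidden_layers=4, activation_sparsity_pattern=[0.95])
except ValueError as err:
    print(err)  # Expected 4 values but got 1.
```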