From 046e05ae9284bb9d90544d4c264e7f2a78aaf78d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 4 Mar 2025 01:38:57 -0800 Subject: [PATCH 01/52] Gemma3 1B working Signed-off-by: Woosuk Kwon --- vllm/config.py | 2 +- vllm/model_executor/models/gemma3.py | 456 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 4 + 4 files changed, 462 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/gemma3.py diff --git a/vllm/config.py b/vllm/config.py index f87d2d6e82cf..ce076d81f220 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -353,7 +353,7 @@ def __init__( sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or - (self.hf_text_config.model_type in ["gemma2", "cohere2"])) + (self.hf_text_config.model_type in ["gemma2", "gemma3", "cohere2"])) if (not self.disable_sliding_window and has_interleaved_attention): if (backend := diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py new file mode 100644 index 000000000000..972f60ee960c --- /dev/null +++ b/vllm/model_executor/models/gemma3.py @@ -0,0 +1,456 @@ +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Optional, Set, Tuple, Union, Literal + +import torch +from torch import nn +# FIXME +# from transformers import Gemma3Config +from transformers.models.gemma3.configuration_gemma3 import Gemma3Config, Gemma3TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +ATTENTION_TYPE_LOCAL = "local_sliding" +AttentionType = Literal["global", "local_sliding"] + + +class Gemma3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + + def __init__(self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # FIXME(woosuk): This seems like a bug in config.json. + if config.query_pre_attn_scalar < 1: + self.scaling = config.query_pre_attn_scalar + else: + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # TODO(woosuk): Add reference to the original HF implementation. + layer_idx = extract_layer_index(prefix) + attn_type = config.attention_pattern[layer_idx % len(config.attention_pattern)] + use_sliding_window = (attn_type == ATTENTION_TYPE_LOCAL) + + # Initialize the rotary embedding. + self.rope_theta = (config.rope_local_base_freq if use_sliding_window else + config.rope_global_base_freq) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + + # Initialize the attention. + sliding_window = (config.interleaved_sliding_window if + use_sliding_window else None) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=None, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) + return loaded_params + + +class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3a7fcdcf7b37..4e058a0c506e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,6 +53,7 @@ "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1937b1388471..38e53991ae3b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -715,6 +715,10 @@ def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ + # FIXME(woosuk): This is a hack because Gemma3's text_config does not match + # its config.json for some reason. Remove this once the issue is fixed. + if config.model_type == "gemma3": + return config if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have # `num_attention_heads` (among others). Assert here to fail early From 669ae5a3951bd8e35493e0096925ccd91d9b3f4e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 5 Mar 2025 16:29:47 -0800 Subject: [PATCH 02/52] Add modeling_gemma3.py Signed-off-by: Woosuk Kwon --- modeling_gemma3.py | 1381 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1381 insertions(+) create mode 100644 modeling_gemma3.py diff --git a/modeling_gemma3.py b/modeling_gemma3.py new file mode 100644 index 000000000000..1ee2d7d056e7 --- /dev/null +++ b/modeling_gemma3.py @@ -0,0 +1,1381 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import Literal, Optional, Union, cast + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, HybridCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..gemma import GemmaPreTrainedModel +from ..siglip import SiglipVisionModel +from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Gemma3Config" + + +class Gemma3RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma3MultimodalInputProjection(nn.Module): + def __init__(self, vision_dim: int, text_dim: int): + super().__init__() + self.weight = nn.Parameter(torch.zeros(vision_dim, text_dim)) + + def forward(self, x): + output = torch.einsum("btm,md->btd", x, self.weight) + return output.type_as(x) + + +class Gemma3MLP(nn.Module): + def __init__(self, config: Gemma3TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +ATTENTION_TYPE_LOCAL = "local_sliding" +AttentionType = Literal["global", "local_sliding"] + + +def create_sliding_window_mask( + position_ids: torch.LongTensor, + cache_position: int, + cache_len: int, + sliding_window_size: int, +) -> torch.Tensor: + """Creates mask for sliding window attention.""" + total_tokens = cache_position + position_ids.shape[1] # cached + processing tokens + + def _reconstruct_rotated_cache_positions(): + cache_positions = torch.arange(cache_len) + total_tokens - cache_len + rotated_cache_positions = torch.zeros_like(cache_positions) + rotated_cache_positions[cache_positions % cache_len] = cache_positions + return rotated_cache_positions + + # Reconstruct position_ids for cached kv. + if total_tokens <= cache_len: + cache_positions = torch.arange(cache_len) + else: + cache_positions = _reconstruct_rotated_cache_positions() + + cache_positions = cache_positions.unsqueeze(0).unsqueeze(0).to(position_ids.device) # [1, 1, cache_len] + position_ids = position_ids.unsqueeze(-1) # [B, seq_len, 1] + sliding_mask = cache_positions > position_ids - sliding_window_size + sliding_mask *= cache_positions < position_ids + sliding_window_size + return sliding_mask.unsqueeze(1) + + +def eager_attention_forward( + module: "Gemma3Attention", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Gemma3Attention(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.attention_dropout = config.attention_dropout + self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.is_causal = True + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar + self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings[self.attention_type] + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if ( + self.is_sliding + and self.sliding_window is not None + and attention_mask is not None + and position_ids is not None + ): + sliding_mask = create_sliding_window_mask( + position_ids=position_ids, + cache_position=last_cache_position, + cache_len=attention_mask.shape[-1], + sliding_window_size=self.sliding_window, + ) + attention_mask *= sliding_mask + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + # Why is this a torch.LongTensor? Feels like last_cache_position should be in here somehow? + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = ( + key_states[:, :, :seq_len, :], + value_states[:, :, :seq_len, :], + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = cast( + Callable[..., tuple[torch.Tensor, torch.Tensor]], + ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation], + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3DecoderLayer(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_sliding = self.self_attn.is_sliding + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Gemma3RotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3RotaryEmbeddingConfig, device: torch.device = None): + super().__init__() + # BC: "rope_type" was originally "type" + if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: + self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) + else: + self.rope_type = "default" + + self.config = config + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Gemma3PreTrainedModel(GemmaPreTrainedModel): + base_model_prefix = "model" + config_class = Gemma3TextConfig + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma3DecoderLayer"] + + +ATTENTION_TYPE_GLOBAL = "global" + + +class Gemma3Model(Gemma3PreTrainedModel): + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.config = config + + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + ) + self.layers = nn.ModuleList( + [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb_global = Gemma3RotaryEmbedding( + config=Gemma3RotaryEmbeddingConfig( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + rope_theta=config.rope_global_base_freq, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + ) + ) + self.rotary_emb_local = Gemma3RotaryEmbedding( + config=Gemma3RotaryEmbeddingConfig( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + rope_theta=config.rope_local_base_freq, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + ) + ) + self.gradient_checkpointing = False + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not self.training: + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + dtype=inputs_embeds.dtype, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]] = { + ATTENTION_TYPE_GLOBAL: self.rotary_emb_global(hidden_states, position_ids), + ATTENTION_TYPE_LOCAL: self.rotary_emb_local(hidden_states, position_ids), + } + + # normalized + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + @torch.no_grad() + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridCache, + output_attentions: bool, + ): + # Flash Attention currently doesn't support static cache but Gemma3 work only with static cache. + # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape + # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible + # as it doesn't cause dynamic control issues. + if self.config._attn_implementation == "flash_attention_2": + return attention_mask + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +GEMMA3_INPUTS_DOCSTRING = "" + + +class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): + base_model_prefix = "language_model" + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.config = config + self.vocab_size = config.vocab_size + self.model = Gemma3Model(config) + self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten: has a special cache type, `HybridCache` + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if past_key_values is not None: + if inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s + # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride + # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the + # batch size = 1 case, `position_ids` is already contiguous but with varying stride + # which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format), + "inputs_embeds": None, + } + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + + if ( + isinstance(past_key_values, HybridCache) + and attention_mask.ndim == 2 + and not self.config._attn_implementation == "flash_attention_2" + ): + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + ) + + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for PaliGemmacausal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each + tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the + embeddings, if the model has an embedding layer, + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin): + config_class = Gemma3Config + supports_gradient_checkpointing = True + + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_quantized_cache = True + _supports_sdpa = True + _supports_static_cache = True + + def __init__(self, config: Gemma3Config): + super().__init__(config) + + self.config = config + text_config = self.config.text_config + vision_config = self.config.vision_config + if vision_config is None: + raise ValueError( + "Atempted to initialize a `Gemma3ForConditionalGeneration` instance without a `Gemma3VisionConfig`; " + "either provide a vision config or use a `Gemma3ForCausalLM` instace for text-only generation." + ) + + self.language_model = Gemma3ForCausalLM(config=text_config) + self.vision_model = SiglipVisionModel(config=vision_config) + self.mm_input_projection = Gemma3MultimodalInputProjection( + vision_dim=vision_config.hidden_size, text_dim=text_config.hidden_size + ) + self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps) + + patches_per_image = vision_config.image_size // vision_config.patch_size + avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image + self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k) + self.vocab_size = text_config.vocab_size + self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1 + self.post_init() + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_input_embeddings with PaliGemma->Gema3 + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_input_embeddings with PaliGemma->Gema3 + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_output_embeddings with PaliGemma->Gema3 + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_output_embeddings with PaliGemma->Gema3 + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_decoder with PaliGemma->Gema3 + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_decoder with PaliGemma->Gema3 + def get_decoder(self): + return self.language_model.get_decoder() + + def encode_vision(self, x: torch.Tensor) -> torch.Tensor: + x = self.mm_soft_emb_norm(x) + x = self.mm_input_projection(x) + return x + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state + b, n, l = vision_outputs.shape + reshaped_vision_outputs = vision_outputs.permute(0, 2, 1) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n) + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1) + image_features = self.encode_vision(pooled_vision_outputs) + return image_features + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_soft_token_mask: Optional[torch.BoolTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[tuple, Gemma3CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 + + # Merge text and images + if pixel_values is not None: + if image_soft_token_mask is None: + raise ValueError( + "Cannot join vision and language embeddings wihtout an `image_soft_token_mask`. " + "Use Gemma3Processor to create one." + ) + + image_features = self.get_image_features(pixel_values).to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask = image_soft_token_mask.unsqueeze(-1) + image_mask = image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if (emb_nel := inputs_embeds[image_mask].numel()) != (img_nel := image_features.numel()): + raise ValueError( + f"Number of image features ({img_nel}) does not match number of special image tokens in the input " + f"text ({emb_nel}). " + ) + + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self._update_causal_mask( + attention_mask, + token_type_ids, + past_key_values, + cache_position, + inputs_embeds, + is_training, + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs.logits + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + image_soft_token_mask=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + + # If we're in cached decoding stage, pixel values should be None because + # input ids do not contain special image tokens anymore. Otherwise we + # need pixel values to be passed to model. + # NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["image_soft_token_mask"] = image_soft_token_mask + + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self._update_causal_mask( + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training, + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device + ) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + # we are training thus we need to create a full mask on the image + prefix but causal on suffix + if is_training: + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + return causal_mask + + +__all__ = ["Gemma3PreTrainedModel", "Gemma3Model", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] From e044645d064e850192352bde9265080fe1b5098b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Mar 2025 21:19:35 -0800 Subject: [PATCH 03/52] remove Signed-off-by: Woosuk Kwon --- modeling_gemma3.py | 1381 -------------------------------------------- 1 file changed, 1381 deletions(-) delete mode 100644 modeling_gemma3.py diff --git a/modeling_gemma3.py b/modeling_gemma3.py deleted file mode 100644 index 1ee2d7d056e7..000000000000 --- a/modeling_gemma3.py +++ /dev/null @@ -1,1381 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma3.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Callable, Mapping -from dataclasses import dataclass -from typing import Literal, Optional, Union, cast - -import torch -import torch.nn as nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - add_start_docstrings_to_model_forward, - is_torchdynamo_compiling, - logging, - replace_return_docstrings, -) -from ...utils.deprecation import deprecate_kwarg -from ..gemma import GemmaPreTrainedModel -from ..siglip import SiglipVisionModel -from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig - - -logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "Gemma3Config" - - -class Gemma3RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = output * (1.0 + self.weight.float()) - return output.type_as(x) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" - - -class Gemma3MultimodalInputProjection(nn.Module): - def __init__(self, vision_dim: int, text_dim: int): - super().__init__() - self.weight = nn.Parameter(torch.zeros(vision_dim, text_dim)) - - def forward(self, x): - output = torch.einsum("btm,md->btd", x, self.weight) - return output.type_as(x) - - -class Gemma3MLP(nn.Module): - def __init__(self, config: Gemma3TextConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_activation] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -ATTENTION_TYPE_LOCAL = "local_sliding" -AttentionType = Literal["global", "local_sliding"] - - -def create_sliding_window_mask( - position_ids: torch.LongTensor, - cache_position: int, - cache_len: int, - sliding_window_size: int, -) -> torch.Tensor: - """Creates mask for sliding window attention.""" - total_tokens = cache_position + position_ids.shape[1] # cached + processing tokens - - def _reconstruct_rotated_cache_positions(): - cache_positions = torch.arange(cache_len) + total_tokens - cache_len - rotated_cache_positions = torch.zeros_like(cache_positions) - rotated_cache_positions[cache_positions % cache_len] = cache_positions - return rotated_cache_positions - - # Reconstruct position_ids for cached kv. - if total_tokens <= cache_len: - cache_positions = torch.arange(cache_len) - else: - cache_positions = _reconstruct_rotated_cache_positions() - - cache_positions = cache_positions.unsqueeze(0).unsqueeze(0).to(position_ids.device) # [1, 1, cache_len] - position_ids = position_ids.unsqueeze(-1) # [B, seq_len, 1] - sliding_mask = cache_positions > position_ids - sliding_window_size - sliding_mask *= cache_positions < position_ids + sliding_window_size - return sliding_mask.unsqueeze(1) - - -def eager_attention_forward( - module: "Gemma3Attention", - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - if scaling is None: - scaling = module.head_dim**-0.5 - - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -class Gemma3Attention(nn.Module): - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.attention_dropout = config.attention_dropout - self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] - self.config = config - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.is_causal = True - self.layer_idx = layer_idx - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = config.query_pre_attn_scalar - self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL - self.sliding_window = config.sliding_window if self.is_sliding else None - - self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, - ) - self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]], - attention_mask: Optional[torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - cos, sin = position_embeddings[self.attention_type] - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if ( - self.is_sliding - and self.sliding_window is not None - and attention_mask is not None - and position_ids is not None - ): - sliding_mask = create_sliding_window_mask( - position_ids=position_ids, - cache_position=last_cache_position, - cache_len=attention_mask.shape[-1], - sliding_window_size=self.sliding_window, - ) - attention_mask *= sliding_mask - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - # Why is this a torch.LongTensor? Feels like last_cache_position should be in here somehow? - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = ( - key_states[:, :, :seq_len, :], - value_states[:, :, :seq_len, :], - ) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = cast( - Callable[..., tuple[torch.Tensor, torch.Tensor]], - ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation], - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=self.attention_dropout if self.training else 0.0, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Gemma3DecoderLayer(nn.Module): - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma3MLP(config) - self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_sliding = self.self_attn.is_sliding - self.sliding_window = config.sliding_window - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, - **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - last_cache_position=last_cache_position, - **kwargs, - ) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class Gemma3RotaryEmbedding(nn.Module): - def __init__(self, config: Gemma3RotaryEmbeddingConfig, device: torch.device = None): - super().__init__() - # BC: "rope_type" was originally "type" - if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: - self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) - else: - self.rope_type = "default" - - self.config = config - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Gemma3PreTrainedModel(GemmaPreTrainedModel): - base_model_prefix = "model" - config_class = Gemma3TextConfig - supports_gradient_checkpointing = True - _no_split_modules = ["Gemma3DecoderLayer"] - - -ATTENTION_TYPE_GLOBAL = "global" - - -class Gemma3Model(Gemma3PreTrainedModel): - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - self.config = config - - self.embed_tokens = nn.Embedding( - config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id, - ) - self.layers = nn.ModuleList( - [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb_global = Gemma3RotaryEmbedding( - config=Gemma3RotaryEmbeddingConfig( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - rope_theta=config.rope_global_base_freq, - head_dim=config.head_dim, - max_position_embeddings=config.max_position_embeddings, - ) - ) - self.rotary_emb_local = Gemma3RotaryEmbedding( - config=Gemma3RotaryEmbeddingConfig( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - rope_theta=config.rope_local_base_freq, - head_dim=config.head_dim, - max_position_embeddings=config.max_position_embeddings, - ) - ) - self.gradient_checkpointing = False - self.post_init() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) - - # embed positions - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings: Mapping[AttentionType, tuple[torch.Tensor, torch.Tensor]] = { - ATTENTION_TYPE_GLOBAL: self.rotary_emb_global(hidden_states, position_ids), - ATTENTION_TYPE_LOCAL: self.rotary_emb_local(hidden_states, position_ids), - } - - # normalized - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - hidden_states = hidden_states * normalizer - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - last_cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - last_cache_position=last_cache_position, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool, - ): - # Flash Attention currently doesn't support static cache but Gemma3 work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. - if self.config._attn_implementation == "flash_attention_2": - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -GEMMA3_INPUTS_DOCSTRING = "" - - -class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): - base_model_prefix = "language_model" - - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - self.config = config - self.vocab_size = config.vocab_size - self.model = Gemma3Model(config) - self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False) - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, - ) -> Union[tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - - if self.training and self.config._attn_implementation != "eager": - logger.warning_once( - "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " - f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **loss_kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if inputs_embeds is not None or ( # Exception 1 - is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] - ): # Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs - - -@dataclass -class Gemma3CausalLMOutputWithPast(ModelOutput): - """ - Base class for PaliGemmacausal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or - when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each - tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or - when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the - embeddings, if the model has an embedding layer, + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when - `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - - -class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin): - config_class = Gemma3Config - supports_gradient_checkpointing = True - - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_quantized_cache = True - _supports_sdpa = True - _supports_static_cache = True - - def __init__(self, config: Gemma3Config): - super().__init__(config) - - self.config = config - text_config = self.config.text_config - vision_config = self.config.vision_config - if vision_config is None: - raise ValueError( - "Atempted to initialize a `Gemma3ForConditionalGeneration` instance without a `Gemma3VisionConfig`; " - "either provide a vision config or use a `Gemma3ForCausalLM` instace for text-only generation." - ) - - self.language_model = Gemma3ForCausalLM(config=text_config) - self.vision_model = SiglipVisionModel(config=vision_config) - self.mm_input_projection = Gemma3MultimodalInputProjection( - vision_dim=vision_config.hidden_size, text_dim=text_config.hidden_size - ) - self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps) - - patches_per_image = vision_config.image_size // vision_config.patch_size - avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image - self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k) - self.vocab_size = text_config.vocab_size - self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1 - self.post_init() - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_input_embeddings with PaliGemma->Gema3 - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_input_embeddings with PaliGemma->Gema3 - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_output_embeddings with PaliGemma->Gema3 - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_output_embeddings with PaliGemma->Gema3 - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.set_decoder with PaliGemma->Gema3 - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - # Copied from transformers.models.paligemma.modeling_paligemma.PaliGemmaForConditionalGeneration.get_decoder with PaliGemma->Gema3 - def get_decoder(self): - return self.language_model.get_decoder() - - def encode_vision(self, x: torch.Tensor) -> torch.Tensor: - x = self.mm_soft_emb_norm(x) - x = self.mm_input_projection(x) - return x - - def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: - """ - Projects the last hidden state from the vision model into language model space. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). - """ - vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state - b, n, l = vision_outputs.shape - reshaped_vision_outputs = vision_outputs.permute(0, 2, 1) - reshaped_vision_outputs = reshaped_vision_outputs.contiguous() - reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n) - pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) - pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1) - image_features = self.encode_vision(pooled_vision_outputs) - return image_features - - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - image_soft_token_mask: Optional[torch.BoolTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, - ) -> Union[tuple, Gemma3CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 - - # Merge text and images - if pixel_values is not None: - if image_soft_token_mask is None: - raise ValueError( - "Cannot join vision and language embeddings wihtout an `image_soft_token_mask`. " - "Use Gemma3Processor to create one." - ) - - image_features = self.get_image_features(pixel_values).to(inputs_embeds.device, inputs_embeds.dtype) - - image_mask = image_soft_token_mask.unsqueeze(-1) - image_mask = image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if (emb_nel := inputs_embeds[image_mask].numel()) != (img_nel := image_features.numel()): - raise ValueError( - f"Number of image features ({img_nel}) does not match number of special image tokens in the input " - f"text ({emb_nel}). " - ) - - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, - token_type_ids, - past_key_values, - cache_position, - inputs_embeds, - is_training, - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - **lm_kwargs, - ) - - logits = outputs.logits - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - image_soft_token_mask=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - logits_to_keep=None, - labels=None, - **kwargs, - ): - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - - # If we're in cached decoding stage, pixel values should be None because - # input ids do not contain special image tokens anymore. Otherwise we - # need pixel values to be passed to model. - # NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - model_inputs["image_soft_token_mask"] = image_soft_token_mask - - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training, - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - return causal_mask - - -__all__ = ["Gemma3PreTrainedModel", "Gemma3Model", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] From bd78da5a64c9f60115d12c17a9ad7b45a9160c3c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Mar 2025 21:59:17 -0800 Subject: [PATCH 04/52] [TMP] Add HF Gemma 3 Signed-off-by: Woosuk Kwon --- hf-gemma3/__init__.py | 29 + hf-gemma3/configuration_gemma3.py | 281 ++++ .../convert_gemma3_weights_orbax_to_hf.py | 537 ++++++ hf-gemma3/image_processing_gemma3.py | 405 +++++ hf-gemma3/modeling_gemma3.py | 1466 +++++++++++++++++ hf-gemma3/modular_gemma3.py | 1036 ++++++++++++ hf-gemma3/processing_gemma3.py | 159 ++ 7 files changed, 3913 insertions(+) create mode 100644 hf-gemma3/__init__.py create mode 100644 hf-gemma3/configuration_gemma3.py create mode 100644 hf-gemma3/convert_gemma3_weights_orbax_to_hf.py create mode 100644 hf-gemma3/image_processing_gemma3.py create mode 100644 hf-gemma3/modeling_gemma3.py create mode 100644 hf-gemma3/modular_gemma3.py create mode 100644 hf-gemma3/processing_gemma3.py diff --git a/hf-gemma3/__init__.py b/hf-gemma3/__init__.py new file mode 100644 index 000000000000..e8e1c60bd56e --- /dev/null +++ b/hf-gemma3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gemma3 import * + from .modeling_gemma3 import * + from .image_processing_gemma3 import * + from .processing_gemma3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/hf-gemma3/configuration_gemma3.py b/hf-gemma3/configuration_gemma3.py new file mode 100644 index 000000000000..1dec0bc30ae2 --- /dev/null +++ b/hf-gemma3/configuration_gemma3.py @@ -0,0 +1,281 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Literal, Optional, cast + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging +from ..siglip import SiglipVisionConfig + + +logger = logging.get_logger(__name__) + +ATTENTION_TYPE_GLOBAL = "global" +ATTENTION_TYPE_LOCAL = "local_sliding" +AttentionType = Literal["global", "local_sliding"] +AttentionPattern = Sequence[AttentionType] +DEFAULT_ATTENION_PATTERN = cast( + AttentionPattern, + ( + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_GLOBAL, + ), +) + + +class Gemma3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma3-7B. + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3Model`] + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to + `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` + activation function. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_local_base_freq (float, *optional*, defaults to `rope_theta`): + The base period of the RoPE embeddings for local attention. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_pattern (Sequence[AttentionTypes], defaults to (5 * local, global)): + The attention pattern to apply + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to None): + The scaling factor used on the attention scores, not that + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window + attention. This is the size of the sliding window. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh soft-capping + on the attention scorexs. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + + ```python + >>> from transformers import Gemma3Model, Gemma3Config + >>> # Initializing a Gemma3 gemma3-7b style configuration + >>> configuration = Gemma3Config() + >>> # Initializing a model from the gemma3-7b style configuration + >>> model = Gemma3Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3_text" + + def __init__( + self, + # Config parameters found in all implementations, name differences noted + vocab_size: int = 262_144, # num_embed in FLAX + hidden_size: int = 2304, # embed_dim in FLAX + intermediate_size: int = 9216, # hidden_dim in FLAX + num_hidden_layers: int = 26, # num_layers in FLAX + num_attention_heads: int = 8, # num_heads in FLAX + num_key_value_heads: int = 4, # num_kv_heads in FLAX + head_dim: int = 256, + sliding_window: int = 4096, # sliding_window_size in FLAX + query_pre_attn_scalar: Optional[float] = None, + attention_pattern: AttentionPattern = DEFAULT_ATTENION_PATTERN, + rope_theta: float = 1_000_000.0, + rope_scaling=None, + rope_local_base_freq: float = 10_000.0, + rms_norm_eps: float = 1e-6, + hidden_activation: str = "gelu_pytorch_tanh", + pad_token_id: int = 0, + eos_token_id: int = 1, + bos_token_id: int = 2, + tie_word_embeddings: bool = True, + max_position_embeddings: int = 131_072, + initializer_range: float = 0.02, + attention_bias: bool = False, + attention_dropout: float = 0.0, + use_cache: bool = True, + final_logit_softcapping=None, + attn_logit_softcapping=None, + cache_implementation: str = "hybrid", + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.attention_pattern = attention_pattern + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = len(self.attention_pattern) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation + rope_config_validation(self) + + +class Gemma3Config(PretrainedConfig): + model_type = "gemma3" + sub_configs = { + "text_config": Gemma3TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Gemma3TextConfig] = None, + vision_config: Optional[SiglipVisionConfig] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = Gemma3TextConfig() + logger.info("text_config is None, using default Gemma3TextConfig vision config.") + elif isinstance(text_config, dict): + text_config = Gemma3TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + else: + vision_config = SiglipVisionConfig() + logger.info( + "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited " + "to text tasks." + ) + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +__all__ = ["Gemma3Config", "Gemma3TextConfig"] diff --git a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py new file mode 100644 index 000000000000..6f116976bc4d --- /dev/null +++ b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -0,0 +1,537 @@ +r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. + +python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \ + --variant='gemma3_4b' \ + --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ + --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ + --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \ + --precision='bfloat16' +""" + +import dataclasses +import math +from collections.abc import Iterator, Sequence +from typing import Any + +import accelerate +import numpy as np +import torch +import tree +from absl import app, flags, logging +from orbax import checkpoint as obc + +from ..gemma import GemmaTokenizerFast +from ...image_utils import PILImageResampling +from . import ( + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3Processor, + Gemma3ImageProcessor, +) +from .configuration_gemma3 import ( + DEFAULT_ATTENION_PATTERN, + Gemma3Config, + Gemma3TextConfig, + SiglipVisionConfig, +) + + +# ==== Internal Constants and Classes ==== + +_CHAT_TEMPLATE = ( + "{{ bos_token }}{% set system_message = '' %}{% if messages[0]['role'] == 'system' %}" + "{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}" + "{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}" + "{% if loop.index0 == 0 and message['role'] == 'user' %}" + "{{ '' + message['role'] + '\n' + system_message + message['content'] | trim + '\n' }}" + "{% elif (message['role'] == 'assistant') %}{% set role = 'model' %}" + "{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% else %}" + "{{ '' + message['role'] + '\n' + message['content'] | trim + '\n' }}{% endif %}" + "{% endfor %}{% if add_generation_prompt %}{{ 'model\n' }}{% endif %}" +) + +_DTYPES = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + +_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder" +_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) +_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" + +_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = "transformer/embedder" +_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +_VISION_CONFIG = { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "num_channels": 3, + "image_size": 896, + "patch_size": 14, + "hidden_act": "gelu_pytorch_tanh", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "vision_use_head": False, +} + +_VARIANT_GEMMA_3_1B = "gemma3_1b" +_VARIANT_GEMMA_3_4B = "gemma3_4b" +_VARIANT_GEMMA_3_12B = "gemma3_12b" +_VARIANT_GEMMA_3_27B = "gemma3_27b" +_VARIANTS = { + _VARIANT_GEMMA_3_1B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=1152, + intermediate_size=6912, + num_attention_heads=4, + num_hidden_layers=26, + num_key_value_heads=1, + attention_pattern=DEFAULT_ATTENION_PATTERN, + sliding_window=512, + rope_global_base_freq=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256**-0.5, + max_position_embeddings=32_768, + ), + vision_config=None, + ), + _VARIANT_GEMMA_3_4B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=2560, + intermediate_size=10_240, + num_attention_heads=8, + num_hidden_layers=34, + num_key_value_heads=4, + attention_pattern=DEFAULT_ATTENION_PATTERN, + sliding_window=1024, + rope_global_base_freq=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256**-0.5, + ), + vision_config=_VISION_CONFIG, + ), + _VARIANT_GEMMA_3_12B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=3840, + intermediate_size=3840 * 8 // 2, + num_attention_heads=16, + num_hidden_layers=48, + num_key_value_heads=8, + attention_pattern=DEFAULT_ATTENION_PATTERN, + sliding_window=1024, + rope_global_base_freq=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256**-0.5, + ), + vision_config=_VISION_CONFIG, + ), + _VARIANT_GEMMA_3_27B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=5376, + intermediate_size=5376 * 8 // 2, + num_attention_heads=32, + num_hidden_layers=62, + num_key_value_heads=16, + head_dim=128, + attention_pattern=DEFAULT_ATTENION_PATTERN, + sliding_window=1024, + rope_global_base_freq=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=1 / math.sqrt(5376 // 32), # 1 / sqrt(hidden_size // num_attention_heads) + ), + vision_config=_VISION_CONFIG, + ), +} + +# ==== Flags ==== + +CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path", + default=None, + help="Path to the Orbax checkpoint.", + required=True, +) + +INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool( + name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" +) + +OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +PRECISION = flags.DEFINE_enum( + name="precision", + default=None, + help="The floating point precision (aka dtype) of the model.", + enum_values=set(_DTYPES.keys()), + required=True, +) + +_TEXT_ONLY = flags.DEFINE_bool( + name="text_only", + default=False, + help=( + "If True, the model is loaded and saved as a Gemma3ForCausalLM, " + "otherwise model saed as Gemma3ForConditionalGeneration." + ), +) + +TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + +_VARIANT = flags.DEFINE_enum( + name="variant", + default=_VARIANT_GEMMA_3_4B, + help="The model variant to convert.", + enum_values=set(_VARIANTS.keys()), +) + + +def convert_siglip_weight( + config: SiglipVisionConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> tuple[str, np.ndarray]: + path, prop = paths + normalized_path: str = "" + updated_weights: np.ndarray = None + + if path == _SIGLIP_BASE: + normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight" + updated_weights = weights.reshape(-1, config.hidden_size) + elif path == _SIGLIP_EMBEDDING: + if prop == "kernel": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight" + updated_weights = weights.transpose(3, 2, 0, 1) + elif prop == "bias": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK): + encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:] + next_path_seperator_idx = encoder_block_path.find("/") + layer_idx = encoder_block_path[:next_path_seperator_idx] + encoder_block_path = encoder_block_path[next_path_seperator_idx:] + normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + + if encoder_block_path.startswith("/LayerNorm"): + normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2" + + if prop == "scale": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + elif encoder_block_path.startswith("/MlpBlock_0"): + normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2" + + if prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"): + if encoder_block_path.endswith("/key"): + normalized_path += ".self_attn.k_proj" + elif encoder_block_path.endswith("/out"): + normalized_path += ".self_attn.out_proj" + elif encoder_block_path.endswith("/query"): + normalized_path += ".self_attn.q_proj" + elif encoder_block_path.endswith("/value"): + normalized_path += ".self_attn.v_proj" + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.") + + if prop == "bias": + normalized_path += ".bias" + updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1) + elif prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.reshape(-1, config.hidden_size).transpose() + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.") + elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM: + if prop == "scale": + normalized_path = "vision_tower.vision_model.post_layernorm.weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path = "vision_tower.vision_model.post_layernorm.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if "vision" in normalized_path: + print(normalized_path) + return normalized_path, updated_weights + + +def convert_transformer_weights( + config: Gemma3TextConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> Iterator[tuple[str, np.ndarray]]: + path, prop = paths + + if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): + path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] + + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + attn_head_dim = config.num_attention_heads * config.head_dim + kv_head_dim = config.num_key_value_heads * config.head_dim + + if path == _TRANSFORMER_EMBEDDER: + if prop == "input_embedding": + # Tied to language_model.lm_head.weight, assigned at the end. + converted_paths = ["language_model.model.embed_tokens.weight"] + converted_weights = [weights] + elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"): + return zip([], []) + else: + raise ValueError(f"Unexpected member, {prop}, in Embedder.") + elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): + if _TEXT_ONLY.value: + return zip([], []) + + if path.endswith("/mm_input_projection"): + converted_paths = ["multi_modal_projector.mm_input_projection_weight"] + converted_weights = [weights] + elif path.endswith("/mm_soft_embedding_norm"): + converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["language_model.model.norm.weight"] + converted_weights = [weights] + elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] + next_path_seperator_idx = decoder_block_path.find("/") + layer_idx = decoder_block_path[:next_path_seperator_idx] + decoder_block_path = decoder_block_path[next_path_seperator_idx:] + + base_path = f"language_model.model.layers.{layer_idx}" + + if path.endswith("attn/attn_vec_einsum"): + converted_paths = [f"{base_path}.self_attn.o_proj.weight"] + converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)] + elif path.endswith("attn/_key_norm"): + converted_paths = [f"{base_path}.self_attn.k_norm.weight"] + converted_weights = [weights] + elif path.endswith("attn/kv_einsum"): + converted_paths = [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + k_proj_weights, v_proj_weights = weights + converted_weights = [ + k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + ] + elif path.endswith("attn/q_einsum"): + converted_paths = [f"{base_path}.self_attn.q_proj.weight"] + converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)] + elif path.endswith("attn/_query_norm"): + converted_paths = [f"{base_path}.self_attn.q_norm.weight"] + converted_weights = [weights] + elif path.endswith("mlp/gating_einsum"): + converted_paths = [ + f"{base_path}.mlp.gate_proj.weight", + f"{base_path}.mlp.up_proj.weight", + ] + gate_proj_weight, up_proj_weight = weights + converted_weights = [gate_proj_weight, up_proj_weight] + elif path.endswith("mlp/linear"): + converted_paths = [f"{base_path}.mlp.down_proj.weight"] + converted_weights = [weights.transpose()] + elif path.endswith("post_attention_norm"): + converted_paths = [f"{base_path}.post_attention_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("post_ffw_norm"): + converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_attention_norm"): + converted_paths = [f"{base_path}.input_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_ffw_norm"): + converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected path `{path}` in Decoder Block.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def transpose_reshape(x: torch.Tensor) -> torch.Tensor: + x = x.transpose(1, 2) + return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous() + + +@dataclasses.dataclass(frozen=True) +class ConversionResult: + state_tree: dict[str, torch.Tensor] + config: Gemma3Config + + +def convert( + checkpoint_path: str, + config: Gemma3Config, + target_dtype: torch.dtype, +) -> ConversionResult: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + checkpointer = obc.PyTreeCheckpointer() + ckpt = checkpointer.restore(checkpoint_path) + hf_tree: dict[str, torch.Tensor] = {} + + def update_tree(path: str, weights: np.ndarray) -> None: + torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype) + logging.info( + "%s converted shape=%s with dtype=%s", + path, + weights.shape, + torch_tensor.dtype, + ) + hf_tree[path] = torch_tensor + + for paths, value in tree.flatten_with_path(ckpt): + if paths[0].startswith("SigLiPFromPatches_"): + if config.vision_config is None: + continue + + path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) + update_tree(path, weights) + else: + for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): + if config.vision_config is None: + path = path[len("language_model.") :] + + update_tree(path, weights) + + if config.vision_config is None: + hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] + else: + hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"] + + return ConversionResult(state_tree=hf_tree, config=config) + + +def main(*args): + del args + + variant = _VARIANT.value + dtype = getattr(torch, PRECISION.value) + config = _VARIANTS[variant] + output_path = OUTPUT_PATH.value + + if variant == _VARIANT_GEMMA_3_1B: + flags.FLAGS.set_default(_TEXT_ONLY.name, True) + + tokenizer = GemmaTokenizerFast( + TOKENIZER_PATH.value, + add_bos_token=True, + extra_special_tokens={ + "image_token": "", # Should be ID=262_144 + "boi_token": "", # Should be ID=255_999 + "eoi_token": "", # Should be ID=256_000 + }, + ) + + if INCLUDE_CHAT_TEMPLATE.value: + tokenizer.chat_template = _CHAT_TEMPLATE + + if _TEXT_ONLY.value: + config.vision_config = None + tokenizer.save_pretrained(output_path) + logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) + del tokenizer + else: + image_processor = Gemma3ImageProcessor( + image_seq_length=256, + image_mean=(127.5,) * 3, + image_std=(127.5,) * 3, + size={"height": 896, "width": 896}, + do_rescale=False, + resample=PILImageResampling.BILINEAR, + ) + processor = Gemma3Processor( + image_processor=image_processor, + tokenizer=tokenizer, + ) + processor.save_pretrained(output_path) + logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) + del processor + del tokenizer + + logging.info("Gemma 3 (%s) configured as: %s", variant, config) + logging.info("Converting Gemma 3 (%s) @ %s", variant, dtype) + result = convert(CHECKPOINT_PATH.value, config, dtype) + logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant) + + with accelerate.init_empty_weights(): + if config.vision_config is None: + model = Gemma3ForCausalLM(config=config.text_config) + else: + model = Gemma3ForConditionalGeneration(config) + + model.load_state_dict(result.state_tree, assign=True, strict=True) + model.config.torch_dtype = dtype + logging.info("Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.", variant, type(model).__name__) + model.save_pretrained(output_path, safe_serialization=True) + logging.info( + "Saved Gemma 3 (%s) to SafeTensors in %s using %s", + variant, + output_path, + type(model).__name__, + ) + del model + del result + + +if __name__ == "__main__": + app.run(main) diff --git a/hf-gemma3/image_processing_gemma3.py b/hf-gemma3/image_processing_gemma3.py new file mode 100644 index 000000000000..2787057cf618 --- /dev/null +++ b/hf-gemma3/image_processing_gemma3.py @@ -0,0 +1,405 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Gemma3.""" + +import itertools +import math +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_nested_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class Gemma3ImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + """ + + model_input_names = ["pixel_values", "num_crops"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = False, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.do_pan_and_scan = do_pan_and_scan + self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size + self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops + self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate + + def pan_and_scan( + self, + image: np.ndarray, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pan and Scan and image, whatever it means. TODO: write-up docs + + Args: + image (`np.ndarray`): + Image to resize. + pan_and_scan_min_crop_size (`int`): + Size of pan_and_scan_min_crop_size. + pan_and_scan_max_num_crops (`int`): + pan_and_scan_max_num_crops for the image. + pan_and_scan_min_ratio_to_activate (`int`): + pan_and_scan_min_ratio_to_activate for the image.. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + height, width = get_image_size(image) + + # Square or landscape image. + if width >= height: + # Only apply PaS if the image is sufficiently exaggerated + if width / height < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. + num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + + # Portrait image. + else: + # Only apply PaS if the image is sufficiently exaggerated + if height / width < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_h = int(math.floor(height / width + 0.5)) + num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(width / num_crops_w)) + crop_size_h = int(math.ceil(height / num_crops_h)) + + # Don't apply PaS if crop size is too small. + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return [] + + crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] + crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] + + return [ + image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + + def _process_images_for_pas( + self, + images: List[np.ndarray], + do_pan_and_scan: bool, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + pas_images_list = [] + num_crops = [] + for image in images: + pas_images = self.pan_and_scan( + image=image, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + pas_images_list.extend([image] + pas_images) + num_crops.append(len(pas_images)) + return pas_images_list, num_crops + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to apply `pan_and_scan` to images. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pan_and_scan = do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan + pan_and_scan_min_crop_size = ( + pan_and_scan_min_crop_size if pan_and_scan_min_crop_size is not None else self.pan_and_scan_min_crop_size + ) + pan_and_scan_max_num_crops = ( + pan_and_scan_max_num_crops if pan_and_scan_max_num_crops is not None else self.pan_and_scan_max_num_crops + ) + pan_and_scan_min_ratio_to_activate = ( + pan_and_scan_min_ratio_to_activate + if pan_and_scan_min_ratio_to_activate is not None + else self.pan_and_scan_min_ratio_to_activate + ) + + images_list = make_nested_list_of_images(images) + + if not valid_images(images_list[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images_list = [[convert_to_rgb(image) for image in images] for images in images_list] + + # All transformations expect numpy arrays. + images_list = [[to_numpy_array(image) for image in images] for images in images_list] + + if do_rescale and is_scaled_image(images_list[0][0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images_list[0][0]) + + if do_pan_and_scan: + images_list_and_num_crops = [ + self._process_images_for_pas( + images=images, + do_pan_and_scan=do_pan_and_scan, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + for images in images_list + ] + images_list = [images for images, _ in images_list_and_num_crops] + num_crops = [num_crops for _, num_crops in images_list_and_num_crops] + else: + num_crops = [[0] for images in images_list] + + if do_resize: + height, width = size["height"], size["width"] + images_list = [ + [ + resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) + for image in images + ] + for images in images_list + ] + + if do_rescale: + images_list = [ + [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + for images in images_list + ] + + if do_normalize: + images_list = [ + [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + for images in images_list + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for images in images_list + for image in images + ] + + data = {"pixel_values": images, "num_crops": num_crops} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Gemma3ImageProcessor"] diff --git a/hf-gemma3/modeling_gemma3.py b/hf-gemma3/modeling_gemma3.py new file mode 100644 index 000000000000..8ee518286ecc --- /dev/null +++ b/hf-gemma3/modeling_gemma3.py @@ -0,0 +1,1466 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from dataclasses import dataclass +from typing import List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, HybridCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Gemma3Config" + + +class Gemma3MLP(nn.Module): + def __init__(self, config: Gemma3TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Gemma3RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma3RotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +ATTENTION_TYPE_GLOBAL = "global" +ATTENTION_TYPE_LOCAL = "local_sliding" +AttentionType = Literal["global", "local_sliding"] + + +def create_sliding_window_mask( + sliding_window_size: int, + q_pos: torch.Tensor, + kv_pos: torch.Tensor, +) -> torch.Tensor: + """Creates mask for sliding window attention.""" + return q_pos < kv_pos + sliding_window_size + + +def eager_attention_forward( + module: "Gemma3Attention", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Gemma3Attention(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.attention_dropout = config.attention_dropout + self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.is_causal = True + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar + self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + if self.attention_type == ATTENTION_TYPE_GLOBAL: + cos, sin = position_embeddings_global + else: + cos, sin = position_embeddings_local + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + + if self.is_sliding and key_states.shape[-2] > self.sliding_window: + assert self.sliding_window is not None + if query_states.shape[-2] == key_states.shape[-2]: + sliding_window_mask = create_sliding_window_mask( + sliding_window_size=self.sliding_window, + q_pos=torch.arange(query_states.shape[-2]).unsqueeze(-1), + kv_pos=torch.arange(key_states.shape[-2]).unsqueeze(-2), + ) + attention_mask = torch.logical_and(attention_mask, sliding_window_mask) + else: + raise ValueError() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask.to(query_states), + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3DecoderLayer(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_sliding = self.self_attn.is_sliding + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + if not isinstance(past_key_value, HybridCache): + raise ValueError("Gemma 3 only supports a HybridCache, required for local vs global attention") + + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window + else: + min_dtype = torch.finfo(attention_mask.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +GEMMA3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Gemma3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.", + GEMMA3_START_DOCSTRING, +) +class Gemma3PreTrainedModel(PreTrainedModel): + config_class = Gemma3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +GEMMA3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.", + GEMMA3_START_DOCSTRING, +) +class Gemma3Model(Gemma3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3DecoderLayer`] + + Args: + config: Gemma3Config + """ + + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas + # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # normalized + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype) + inputs_embeds = inputs_embeds * normalizer + + if use_cache and past_key_values is None and not self.training: + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + dtype=inputs_embeds.dtype, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + @torch.no_grad() + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridCache, + output_attentions: bool, + ): + # Flash Attention currently doesn't support static cache but Gemma3 work only with static cache. + # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape + # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible + # as it doesn't cause dynamic control issues. + if self.config._attn_implementation == "flash_attention_2": + return attention_mask + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if isinstance(past_key_values, (HybridCache, StaticCache)): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.model = Gemma3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3ForCausalLM + + >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten: has a special cache type, `HybridCache` + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if past_key_values is not None: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s + # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride + # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the + # batch size = 1 case, `position_ids` is already contiguous but with varying stride + # which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + + if ( + isinstance(past_key_values, HybridCache) + and attention_mask.ndim == 2 + and not self.config._attn_implementation == "flash_attention_2" + ): + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + ) + + if logits_to_keep is not None: + model_inputs["logits_to_keep"] = logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + b, _, l = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Gemma3causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + """The GEMMA3 model which consists of a vision backbone and a language model.""", + GEMMA3_START_DOCSTRING, +) +class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): + def __init__(self, config: Gemma3Config): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModelForCausalLM.from_config(config=config.text_config) + + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + min_dtype = torch.finfo(self.dtype).min + batch_size, sequence_length = input_tensor.shape[:2] + if isinstance(past_key_values, (HybridCache, StaticCache)): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + causal_mask = torch.ones((sequence_length, target_length), dtype=torch.bool) + causal_mask = torch.tril(causal_mask) + causal_mask = causal_mask.to(self.device) + + attention_mask = attention_mask.unsqueeze(-2).to(self.device) + causal_mask = causal_mask.unsqueeze(0).repeat(attention_mask.shape[0], 1, 1) + combined_mask = attention_mask * causal_mask[:, :, :sequence_length] + + image_token_mask = input_tensor == self.config.image_token_index + image_token_mask.to(self.device) + # logger.warning("image_token_mask shape = %s", image_token_mask.shape) + padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0) + padded_mask = padded_mask.to(self.device) + # logger.warning("padded_mask shape = %s", padded_mask.shape) + boundary = padded_mask[:, 1:] > padded_mask[:, :-1] + boundary = boundary.to(self.device) + numbered_boundary = torch.cumsum(boundary, dim=-1) + numbered_boundary = numbered_boundary.to(self.device) + q_block_indices = image_token_mask * numbered_boundary + q_block_indices = q_block_indices.to(self.device) + kv_block_indices = q_block_indices + # logger.warning("q_block_indices/kv_block_indices shape = %s", q_block_indices.shape) + bidirectional_mask = torch.logical_and( + kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), + q_block_indices.unsqueeze(-1) > 0, + ) + bidirectional_mask.to(self.device) + attention_mask = torch.logical_or(combined_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)).to(self.device) + full_attention_mask = torch.zeros((batch_size, 1, sequence_length, target_length)).to(self.device, torch.bool) + full_attention_mask[:, :, :, :sequence_length] = attention_mask + attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(self.device) + + return attention_mask + + def get_image_features(self, pixel_values: torch.Tensor): + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + if inputs_embeds is None: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + # normalized + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype) + inputs_embeds = inputs_embeds * normalizer + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + special_image_mask = special_image_mask.unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs.logits + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Gemma3 are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + +__all__ = ["Gemma3PreTrainedModel", "Gemma3Model", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] diff --git a/hf-gemma3/modular_gemma3.py b/hf-gemma3/modular_gemma3.py new file mode 100644 index 000000000000..5a090c5d26ea --- /dev/null +++ b/hf-gemma3/modular_gemma3.py @@ -0,0 +1,1036 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable, Sequence +from typing import List, Literal, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...cache_utils import Cache, HybridCache, StaticCache +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import ( + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..gemma2.modeling_gemma2 import ( + Gemma2ForCausalLM, + Gemma2MLP, + Gemma2Model, + Gemma2RMSNorm, + Gemma2RotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, +) +from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +from ..siglip import SiglipVisionConfig + + +_CHECKPOINT_FOR_DOC = "google/gemma-3-4b" +_CONFIG_FOR_DOC = "Gemma3Config" + +logger = logging.get_logger(__name__) + +GEMMA3_INPUTS_DOCSTRING = "" + +ATTENTION_TYPE_GLOBAL = "global" +ATTENTION_TYPE_LOCAL = "local_sliding" +AttentionType = Literal["global", "local_sliding"] +AttentionPattern = Sequence[AttentionType] +DEFAULT_ATTENION_PATTERN = cast( + AttentionPattern, + ( + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_LOCAL, + ATTENTION_TYPE_GLOBAL, + ), +) + + +class Gemma3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma3-7B. + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3Model`] + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to + `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` + activation function. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_local_base_freq (float, *optional*, defaults to `rope_theta`): + The base period of the RoPE embeddings for local attention. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_pattern (Sequence[AttentionTypes], defaults to (5 * local, global)): + The attention pattern to apply + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to None): + The scaling factor used on the attention scores, not that + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window + attention. This is the size of the sliding window. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh soft-capping + on the attention scorexs. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + + ```python + >>> from transformers import Gemma3Model, Gemma3Config + >>> # Initializing a Gemma3 gemma3-7b style configuration + >>> configuration = Gemma3Config() + >>> # Initializing a model from the gemma3-7b style configuration + >>> model = Gemma3Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3_text" + + def __init__( + self, + # Config parameters found in all implementations, name differences noted + vocab_size: int = 262_144, # num_embed in FLAX + hidden_size: int = 2304, # embed_dim in FLAX + intermediate_size: int = 9216, # hidden_dim in FLAX + num_hidden_layers: int = 26, # num_layers in FLAX + num_attention_heads: int = 8, # num_heads in FLAX + num_key_value_heads: int = 4, # num_kv_heads in FLAX + head_dim: int = 256, + sliding_window: int = 4096, # sliding_window_size in FLAX + query_pre_attn_scalar: Optional[float] = None, + attention_pattern: AttentionPattern = DEFAULT_ATTENION_PATTERN, + rope_theta: float = 1_000_000.0, + rope_scaling = None, + rope_local_base_freq: float = 10_000.0, + rms_norm_eps: float = 1e-6, + hidden_activation: str = "gelu_pytorch_tanh", + pad_token_id: int = 0, + eos_token_id: int = 1, + bos_token_id: int = 2, + tie_word_embeddings: bool = True, + max_position_embeddings: int = 131_072, + initializer_range: float = 0.02, + attention_bias: bool = False, + attention_dropout: float = 0.0, + use_cache: bool = True, + final_logit_softcapping=None, + attn_logit_softcapping=None, + cache_implementation: str = "hybrid", + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.attention_pattern = attention_pattern + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = len(self.attention_pattern) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation + rope_config_validation(self) + + +class Gemma3Config(PretrainedConfig): + model_type = "gemma3" + sub_configs = { + "text_config": Gemma3TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Gemma3TextConfig] = None, + vision_config: Optional[SiglipVisionConfig] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = Gemma3TextConfig() + logger.info("text_config is None, using default Gemma3TextConfig vision config.") + elif isinstance(text_config, dict): + text_config = Gemma3TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + else: + vision_config = SiglipVisionConfig() + logger.info( + "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited " + "to text tasks." + ) + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +class Gemma3MLP(Gemma2MLP): + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + + +class Gemma3RMSNorm(Gemma2RMSNorm): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + + +class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): + def __init__(self, config: Gemma3TextConfig, device=None): + super().__init__(config) + + +def create_sliding_window_mask( + position_ids: torch.LongTensor, + cache_position: int, + cache_len: int, + sliding_window_size: int, +) -> torch.Tensor: + """Creates mask for sliding window attention.""" + total_tokens = cache_position + position_ids.shape[1] # cached + processing tokens + + def _reconstruct_rotated_cache_positions(): + cache_positions = torch.arange(cache_len) + total_tokens - cache_len + rotated_cache_positions = torch.zeros_like(cache_positions) + rotated_cache_positions[cache_positions % cache_len] = cache_positions + return rotated_cache_positions + + # Reconstruct position_ids for cached kv. + if total_tokens <= cache_len: + cache_positions = torch.arange(cache_len) + else: + cache_positions = _reconstruct_rotated_cache_positions() + + cache_positions = cache_positions.unsqueeze(0).unsqueeze(0).to(position_ids.device) # [1, 1, cache_len] + position_ids = position_ids.unsqueeze(-1) # [B, seq_len, 1] + sliding_mask = cache_positions > position_ids - sliding_window_size + sliding_mask *= cache_positions < position_ids + sliding_window_size + return sliding_mask.unsqueeze(1) + + +def create_sliding_window_mask( + sliding_window_size: int, + q_pos: torch.Tensor, + kv_pos: torch.Tensor, +) -> torch.Tensor: + """Creates mask for sliding window attention.""" + return q_pos < kv_pos + sliding_window_size + + +def eager_attention_forward( + module: "Gemma3Attention", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Gemma3Attention(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.attention_dropout = config.attention_dropout + self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.is_causal = True + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar + self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + if self.attention_type == ATTENTION_TYPE_GLOBAL: + cos, sin = position_embeddings_global + else: + cos, sin = position_embeddings_local + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + + if self.is_sliding and key_states.shape[-2] > self.sliding_window: + assert self.sliding_window is not None + if query_states.shape[-2] == key_states.shape[-2]: + sliding_window_mask = create_sliding_window_mask( + sliding_window_size=self.sliding_window, + q_pos=torch.arange(query_states.shape[-2]).unsqueeze(-1), + kv_pos=torch.arange(key_states.shape[-2]).unsqueeze(-2), + ) + attention_mask = torch.logical_and(attention_mask, sliding_window_mask) + else: + raise ValueError() + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask.to(query_states), + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3DecoderLayer(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_sliding = self.self_attn.is_sliding + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + if not isinstance(past_key_value, HybridCache): + raise ValueError("Gemma 3 only supports a HybridCache, required for local vs global attention") + + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window + else: + min_dtype = torch.finfo(attention_mask.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Gemma3Model(Gemma2Model): + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas + # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # normalized + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype) + inputs_embeds = inputs_embeds * normalizer + + if use_cache and past_key_values is None and not self.training: + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + dtype=inputs_embeds.dtype, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class Gemma3ForCausalLM(Gemma2ForCausalLM): + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + b, _, l = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + +# The only diff on forward is that Gemma3 relies on `input_ids` when creating causal mask, while Paligemma +# passes input embeds. Can be removed when we enable good token type ids +class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + if inputs_embeds is None: + special_image_mask = (input_ids == self.config.image_token_index) + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + # normalized + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype) + inputs_embeds = inputs_embeds * normalizer + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + special_image_mask = special_image_mask.unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs.logits + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + min_dtype = torch.finfo(self.dtype).min + batch_size, sequence_length = input_tensor.shape[:2] + if isinstance(past_key_values, (HybridCache, StaticCache)): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + causal_mask = torch.ones((sequence_length, target_length), dtype=torch.bool) + causal_mask = torch.tril(causal_mask) + causal_mask = causal_mask.to(self.device) + + attention_mask = attention_mask.unsqueeze(-2).to(self.device) + causal_mask = causal_mask.unsqueeze(0).repeat(attention_mask.shape[0], 1, 1) + combined_mask = attention_mask * causal_mask[:, :, :sequence_length] + + image_token_mask = input_tensor == self.config.image_token_index + image_token_mask.to(self.device) + # logger.warning("image_token_mask shape = %s", image_token_mask.shape) + padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0) + padded_mask = padded_mask.to(self.device) + # logger.warning("padded_mask shape = %s", padded_mask.shape) + boundary = padded_mask[:, 1:] > padded_mask[:, :-1] + boundary = boundary.to(self.device) + numbered_boundary = torch.cumsum(boundary, dim=-1) + numbered_boundary = numbered_boundary.to(self.device) + q_block_indices = image_token_mask * numbered_boundary + q_block_indices = q_block_indices.to(self.device) + kv_block_indices = q_block_indices + # logger.warning("q_block_indices/kv_block_indices shape = %s", q_block_indices.shape) + bidirectional_mask = torch.logical_and( + kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), + q_block_indices.unsqueeze(-1) > 0, + ) + bidirectional_mask.to(self.device) + attention_mask = torch.logical_or(combined_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)).to(self.device) + full_attention_mask = torch.zeros((batch_size, 1, sequence_length, target_length)).to(self.device, torch.bool) + full_attention_mask[:, :, :, :sequence_length] = attention_mask + attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(self.device) + + return attention_mask + + +__all__ = [ + "Gemma3Config", + "Gemma3TextConfig", + "Gemma3PreTrainedModel", # noqa: F822 + "Gemma3Model", + "Gemma3ForCausalLM", + "Gemma3ForConditionalGeneration", +] diff --git a/hf-gemma3/processing_gemma3.py b/hf-gemma3/processing_gemma3.py new file mode 100644 index 000000000000..b45cbe73c7ce --- /dev/null +++ b/hf-gemma3/processing_gemma3.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, make_nested_list_of_images +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import to_py_obj + + +class Gemma3ImagesKwargs(ImagesKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + do_convert_rgb: Optional[bool] + + +class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "do_pan_and_scan": False, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + }, + } + + +class Gemma3Processor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "Gemma3ImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + num_mm_soft_tokens_per_image: int = 256, + **kwargs, + ): + self.image_seq_length = getattr(image_processor, "image_seq_length") + self.image_token_id = tokenizer.image_token_id + image_tokens_expanded = "".join([tokenizer.image_token] * num_mm_soft_tokens_per_image) + self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded }{tokenizer.eoi_token}\n\n" + + super().__init__( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=chat_template, + **kwargs, + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos=None, + audio=None, + **kwargs: Unpack[Gemma3ProcessorKwargs], + ) -> BatchFeature: + if text is None and images is None: + raise ValueError("Provide at least one of `text` or `images`.") + + output_kwargs = self._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + image_inputs = {} + if images is not None: + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) + + # Create empty text to be replaced with placeholders + if not text: + text = [" ".join([""] * len(images)) for images in batched_images] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." + ) + + # Replace image tokens by the full expanded sequence + batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) + for prompt, images, num_crops in zip(text, batched_images, batch_num_crops): + image_indexes = [m.start() for m in re.finditer("", prompt)] + + if len(images) != len(image_indexes): + raise ValueError( + f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images." + ) + + # Insert additional image tokens for Pan-and-Scan crops + for num, idx in reversed(list(zip(num_crops, image_indexes))): + if num: + formatted_image_text = ( + "Here is the original image and here are some crops to help you see better " + + " ".join([""] * num) + ) + prompt = prompt[:idx] + formatted_image_text + prompt[idx + len("") :] + + # Expand placeholder image tokens to the full image token sequence + text = [prompt.replace("", self.full_image_sequence) for prompt in text] + + text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + return BatchFeature(data={**text_input, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Gemma3Processor"] From a50c2d59773fe13e579a8717a19430c069761f47 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Mar 2025 22:01:03 -0800 Subject: [PATCH 05/52] Fix config: Signed-off-by: Woosuk Kwon --- vllm/config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ce076d81f220..0ef14f55c289 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -350,10 +350,11 @@ def __init__( if self.enforce_eager is None: self.enforce_eager = False + interleaved_attn_models = ["gemma2", "gemma3", "gemma3_text", "cohere2"] sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or - (self.hf_text_config.model_type in ["gemma2", "gemma3", "cohere2"])) + (self.hf_text_config.model_type in interleaved_attn_models)) if (not self.disable_sliding_window and has_interleaved_attention): if (backend := @@ -2487,11 +2488,11 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": if config_dtype == torch.float32: - if config.model_type == "gemma2": + if config.model_type in ("gemma2", "gemma3", "gemma3_text"): logger.info( - "For Gemma 2, we downcast float32 to bfloat16 instead " - "of float16 by default. Please specify `dtype` if you " - "want to use float16.") + "For Gemma 2 and 3, we downcast float32 to bfloat16 " + "instead of float16 by default. Please specify `dtype` " + "if you want to use float16.") torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 From d2562cb9f15a0604033bcefa17068c66379c2689 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 9 Mar 2025 19:05:57 -0700 Subject: [PATCH 06/52] [TMP] image input Signed-off-by: Woosuk Kwon --- examples/offline_inference/vision_language.py | 13 +- .../vision_language_multi_image.py | 17 + .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 1 + vllm/model_executor/models/gemma3_mm.py | 357 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 4 - 7 files changed, 389 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/models/gemma3_mm.py diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index e2ec36211b86..16d024cd6ca3 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -111,6 +111,16 @@ def run_fuyu(question: str, modality: str): return llm, prompt, stop_token_ids +def run_gemma3(question: str, modality: str): + assert modality == "image" + prompt = f" {question}" + model_name = "gg-hf-g/gemma-3-4b-it-pr" + llm = LLM(model=model_name, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + stop_token_ids = None + return llm, prompt, stop_token_ids + + # GLM-4v def run_glm4v(question: str, modality: str): assert modality == "image" @@ -589,6 +599,7 @@ def run_qwen2_5_vl(question: str, modality: str): "deepseek_vl_v2": run_deepseek_vl2, "florence2": run_florence2, "fuyu": run_fuyu, + "gemma3": run_gemma3, "glm4v": run_glm4v, "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, @@ -689,7 +700,7 @@ def main(args): # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. - sampling_params = SamplingParams(temperature=0.2, + sampling_params = SamplingParams(temperature=0.0, max_tokens=64, stop_token_ids=stop_token_ids) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index b1aec33cff46..e61dbb16b835 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -77,6 +77,22 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): ) +def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: + model_name = "gg-hf-g/gemma-3-4b-it-pr" + llm = LLM(model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}) + prompt = " " * len(image_urls) + question + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=None, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "h2oai/h2ovl-mississippi-800m" @@ -453,6 +469,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData: model_example_map = { "aria": load_aria, "deepseek_vl_v2": load_deepseek_vl2, + "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 7534f0c97798..faa8eb0cc2f8 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -153,6 +153,7 @@ def _test_processing_correctness( "deepseek-ai/deepseek-vl2-tiny", "microsoft/Florence-2-base", "adept/fuyu-8b", + "gg-hf-g/gemma-3-4b-it-pr", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index 97db33b46fad..fca5297bcbb6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -241,6 +241,7 @@ def check_available_online( "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("gg-hf-g/gemma-3-4b-it-pr"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 000000000000..dab46d156eba --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +import re +from collections.abc import Iterable, Mapping, Sequence +from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) + +import torch +from torch import nn +# FIXME +# from transformers import Gemma3Config +from transformers.models.gemma3.configuration_gemma3 import Gemma3Config +from transformers import BatchFeature, ProcessorMixin + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings, flatten_bn) + +logger = init_logger(__name__) + +NUM_TOKENS_PER_IMAGE = 256 + + +class Gemma3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": NUM_TOKENS_PER_IMAGE} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin], + ) -> int: + return NUM_TOKENS_PER_IMAGE + + def get_image_size_with_most_features(self) -> ImageSize: + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=8000, width=50) + + +class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + return ProcessorInputs( + prompt_text=" ".join([""] * num_images), + mm_data=mm_data, + ) + + +class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # FIXME(woosuk): Currently, PaS is not supported. + img_kwargs = mm_kwargs.get("images_kwargs", {}) + if img_kwargs: + img_kwargs["do_pan_and_scan"] = False + else: + img_kwargs = {"do_pan_and_scan": False} + mm_kwargs["images_kwargs"] = img_kwargs + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.tokenizer.image_token + image_tokens_expanded = "".join([image_token] * NUM_TOKENS_PER_IMAGE) + + def get_replacement_gemma3(item_idx: int): + return PromptUpdateDetails( + full=hf_processor.full_image_sequence, + features=image_tokens_expanded, + ) + + return [ + PromptReplacement( + modality="image", + target="", + replacement=get_replacement_gemma3, + ) + ] + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + b, _, l = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder) +class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + config.text_config.architectures = ["Gemma3ForCausalLM"] + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values) + return Gemma3ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + return image_features + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> torch.Tensor: + assert self.vision_tower is not None + pixel_values = image_input["data"] + vision_outputs = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + return self.multi_modal_projector(vision_outputs) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + if multimodal_embeddings is None: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + else: + # NOTE(woosuk): Gemma3 uses vocab_size as the image token index. + # To avoid out-of-range error in the embedding layer, we replace the + # image token index with 0. + safe_input_ids = torch.where( + input_ids == self.config.image_token_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings(safe_input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4e058a0c506e..3555cf1cc4a3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -162,6 +162,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 38e53991ae3b..1937b1388471 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -715,10 +715,6 @@ def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ - # FIXME(woosuk): This is a hack because Gemma3's text_config does not match - # its config.json for some reason. Remove this once the issue is fixed. - if config.model_type == "gemma3": - return config if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have # `num_attention_heads` (among others). Assert here to fail early From 017239edac3f32d0e57c7dbe57a323501269778b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 00:26:53 -0700 Subject: [PATCH 07/52] Update Signed-off-by: Woosuk Kwon --- examples/offline_inference/vision_language.py | 4 +-- .../vision_language_multi_image.py | 4 +-- vllm/config.py | 4 ++- vllm/model_executor/models/gemma3.py | 31 ++++++++++--------- vllm/model_executor/models/gemma3_mm.py | 13 +++++--- 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 16d024cd6ca3..d53073bf22ed 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -113,8 +113,8 @@ def run_fuyu(question: str, modality: str): def run_gemma3(question: str, modality: str): assert modality == "image" - prompt = f" {question}" - model_name = "gg-hf-g/gemma-3-4b-it-pr" + prompt = f" {question}" + model_name = "gg-hf-g/gemma-3-4b" llm = LLM(model=model_name, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index e61dbb16b835..c2de580f12b1 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -78,12 +78,12 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: - model_name = "gg-hf-g/gemma-3-4b-it-pr" + model_name = "gg-hf-g/gemma-3-4b" llm = LLM(model=model_name, max_model_len=8192, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}) - prompt = " " * len(image_urls) + question + prompt = " " * len(image_urls) + question return ModelRequestData( llm=llm, prompt=prompt, diff --git a/vllm/config.py b/vllm/config.py index 0ef14f55c289..5653448ceef1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2624,7 +2624,9 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None: + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. + if rope_scaling is not None and "gemma3" not in hf_config.model_type: # No need to consider "type" key because of patch_rope_scaling when # loading HF config rope_type = rope_scaling["rope_type"] diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 972f60ee960c..7f0e0d500d4b 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -119,11 +119,7 @@ def __init__(self, self.head_dim = head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - # FIXME(woosuk): This seems like a bug in config.json. - if config.query_pre_attn_scalar < 1: - self.scaling = config.query_pre_attn_scalar - else: - self.scaling = config.query_pre_attn_scalar**-0.5 + self.scaling = config.query_pre_attn_scalar self.qkv_proj = QKVParallelLinear( hidden_size, @@ -145,23 +141,29 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) - attn_type = config.attention_pattern[layer_idx % len(config.attention_pattern)] - use_sliding_window = (attn_type == ATTENTION_TYPE_LOCAL) + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) # Initialize the rotary embedding. - self.rope_theta = (config.rope_local_base_freq if use_sliding_window else - config.rope_global_base_freq) + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + self.sliding_window = config.interleaved_sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, + rope_scaling=self.rope_scaling, ) # Initialize the attention. - sliding_window = (config.interleaved_sliding_window if - use_sliding_window else None) self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -169,7 +171,7 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config, logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, + per_layer_sliding_window=self.sliding_window, prefix=f"{prefix}.attn") def forward( @@ -290,7 +292,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + # NOTE(woosuk): Only apply the normalizer to the output of + # vocab embedding. Don't apply it to the vision embedding. + return self.embed_tokens(input_ids) * self.normalizer def forward( self, @@ -304,7 +308,6 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) - hidden_states *= self.normalizer residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index dab46d156eba..14d98047cdaf 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -36,6 +36,7 @@ logger = init_logger(__name__) NUM_TOKENS_PER_IMAGE = 256 +BOI_TOKEN = "" class Gemma3ImagePixelInputs(TypedDict): @@ -92,7 +93,7 @@ def get_dummy_processor_inputs( num_images=num_images) } return ProcessorInputs( - prompt_text=" ".join([""] * num_images), + prompt_text=" ".join([BOI_TOKEN] * num_images), mm_data=mm_data, ) @@ -144,7 +145,7 @@ def get_replacement_gemma3(item_idx: int): return [ PromptReplacement( modality="image", - target="", + target=BOI_TOKEN, replacement=get_replacement_gemma3, ) ] @@ -168,10 +169,12 @@ def __init__(self, config: Gemma3Config): self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): - b, _, l = vision_outputs.shape + batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -180,7 +183,7 @@ def forward(self, vision_outputs: torch.Tensor): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) From de0136aaf370569a468056028c139c782f8e3f9a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 01:03:21 -0700 Subject: [PATCH 08/52] Update Signed-off-by: Woosuk Kwon --- hf-gemma3/__init__.py | 2 +- hf-gemma3/configuration_gemma3.py | 165 +++-- .../convert_gemma3_weights_orbax_to_hf.py | 98 +-- hf-gemma3/image_processing_gemma3.py | 32 +- hf-gemma3/modeling_gemma3.py | 350 +++++------ hf-gemma3/modular_gemma3.py | 572 ++++++++---------- hf-gemma3/processing_gemma3.py | 44 +- 7 files changed, 621 insertions(+), 642 deletions(-) diff --git a/hf-gemma3/__init__.py b/hf-gemma3/__init__.py index e8e1c60bd56e..511b0a38e1d8 100644 --- a/hf-gemma3/__init__.py +++ b/hf-gemma3/__init__.py @@ -19,8 +19,8 @@ if TYPE_CHECKING: from .configuration_gemma3 import * - from .modeling_gemma3 import * from .image_processing_gemma3 import * + from .modeling_gemma3 import * from .processing_gemma3 import * else: import sys diff --git a/hf-gemma3/configuration_gemma3.py b/hf-gemma3/configuration_gemma3.py index 1dec0bc30ae2..276aecc5602c 100644 --- a/hf-gemma3/configuration_gemma3.py +++ b/hf-gemma3/configuration_gemma3.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence -from typing import Literal, Optional, cast +from typing import Optional from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -30,44 +29,26 @@ logger = logging.get_logger(__name__) -ATTENTION_TYPE_GLOBAL = "global" -ATTENTION_TYPE_LOCAL = "local_sliding" -AttentionType = Literal["global", "local_sliding"] -AttentionPattern = Sequence[AttentionType] -DEFAULT_ATTENION_PATTERN = cast( - AttentionPattern, - ( - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_GLOBAL, - ), -) - class Gemma3TextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3-7B. + defaults will yield a similar configuration to that of the Gemma3-4B. e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 256000): + vocab_size (`int`, *optional*, defaults to 262144): Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Gemma3Model`] - num_hidden_layers (`int`, *optional*, defaults to 26): - Number of hidden layers in the Transformer decoder. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. hidden_size (`int`, *optional*, defaults to 2304): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 9216): Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 8): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 4): @@ -80,26 +61,12 @@ class Gemma3TextConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to - `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` - activation function. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_local_base_freq (float, *optional*, defaults to `rope_theta`): - The base period of the RoPE embeddings for local attention. + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window + attention. This is the size of the sliding window. + query_pre_attn_scalar (`float`, *optional*): + The scaling factor used on the attention scores, not that + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings used for global attention. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -137,28 +104,47 @@ class Gemma3TextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - attention_pattern (Sequence[AttentionTypes], defaults to (5 * local, global)): - The attention pattern to apply + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to + `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` + activation function. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to None): - The scaling factor used on the attention scores, not that - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window - attention. This is the size of the sliding window. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh soft-capping - on the attention scorexs. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + final_logit_softcapping (`bool`, *optional*, defaults to `True`): + Whether to apply logit softcapping or nor + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + Scaling factor when applying tanh soft-capping on the attention scorexs. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): + The cache type to be used with `generate`. ```python - >>> from transformers import Gemma3Model, Gemma3Config - >>> # Initializing a Gemma3 gemma3-7b style configuration + >>> from transformers import Gemma3Model, Gemma3TextConfig + >>> # Initializing a Gemma3 gemma3-4b style configuration >>> configuration = Gemma3Config() - >>> # Initializing a model from the gemma3-7b style configuration + >>> # Initializing a model from the gemma3-4b style configuration >>> model = Gemma3Model(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -168,20 +154,19 @@ class Gemma3TextConfig(PretrainedConfig): def __init__( self, - # Config parameters found in all implementations, name differences noted - vocab_size: int = 262_144, # num_embed in FLAX - hidden_size: int = 2304, # embed_dim in FLAX - intermediate_size: int = 9216, # hidden_dim in FLAX - num_hidden_layers: int = 26, # num_layers in FLAX - num_attention_heads: int = 8, # num_heads in FLAX - num_key_value_heads: int = 4, # num_kv_heads in FLAX + vocab_size: int = 262_144, + hidden_size: int = 2304, + intermediate_size: int = 9216, + num_hidden_layers: int = 26, + num_attention_heads: int = 8, + num_key_value_heads: int = 4, head_dim: int = 256, - sliding_window: int = 4096, # sliding_window_size in FLAX + sliding_window: int = 4096, query_pre_attn_scalar: Optional[float] = None, - attention_pattern: AttentionPattern = DEFAULT_ATTENION_PATTERN, rope_theta: float = 1_000_000.0, rope_scaling=None, rope_local_base_freq: float = 10_000.0, + sliding_window_pattern: int = 6, rms_norm_eps: float = 1e-6, hidden_activation: str = "gelu_pytorch_tanh", pad_token_id: int = 0, @@ -220,9 +205,8 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.rope_local_base_freq = rope_local_base_freq - self.attention_pattern = attention_pattern # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = len(self.attention_pattern) + self.sliding_window_pattern = sliding_window_pattern self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation @@ -235,6 +219,53 @@ def __init__( class Gemma3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an + Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a Gemma3 Text config + >>> text_config = Gemma3TextConfig() + + >>> # Initializing a Gemma3 gemma-3-4b style configuration + >>> configuration = Gemma3Config(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3TextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gemma3" sub_configs = { "text_config": Gemma3TextConfig, diff --git a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py index 6f116976bc4d..f9f4c45d45b1 100644 --- a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py +++ b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -9,7 +9,6 @@ """ import dataclasses -import math from collections.abc import Iterator, Sequence from typing import Any @@ -20,16 +19,15 @@ from absl import app, flags, logging from orbax import checkpoint as obc -from ..gemma import GemmaTokenizerFast from ...image_utils import PILImageResampling +from ..gemma import GemmaTokenizerFast from . import ( Gemma3ForCausalLM, Gemma3ForConditionalGeneration, - Gemma3Processor, Gemma3ImageProcessor, + Gemma3Processor, ) from .configuration_gemma3 import ( - DEFAULT_ATTENION_PATTERN, Gemma3Config, Gemma3TextConfig, SiglipVisionConfig, @@ -38,18 +36,44 @@ # ==== Internal Constants and Classes ==== -_CHAT_TEMPLATE = ( - "{{ bos_token }}{% set system_message = '' %}{% if messages[0]['role'] == 'system' %}" - "{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}" - "{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}" - "{% if loop.index0 == 0 and message['role'] == 'user' %}" - "{{ '' + message['role'] + '\n' + system_message + message['content'] | trim + '\n' }}" - "{% elif (message['role'] == 'assistant') %}{% set role = 'model' %}" - "{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% else %}" - "{{ '' + message['role'] + '\n' + message['content'] | trim + '\n' }}{% endif %}" - "{% endfor %}{% if add_generation_prompt %}{{ 'model\n' }}{% endif %}" -) + +_CHAT_TEMPLATE = """{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model\n'}} +{%- endif -%} +""" _DTYPES = { "float32": torch.float32, @@ -93,13 +117,13 @@ text_config=Gemma3TextConfig( vocab_size=262_144, hidden_size=1152, - intermediate_size=6912, + intermediate_size=6 * 1152, num_attention_heads=4, num_hidden_layers=26, num_key_value_heads=1, - attention_pattern=DEFAULT_ATTENION_PATTERN, + head_dim=256, sliding_window=512, - rope_global_base_freq=1_000_000, + rope_theta=1_000_000, # used for global RoPE only rope_local_base_freq=10_000, attn_logit_softcapping=None, query_pre_attn_scalar=256**-0.5, @@ -111,13 +135,14 @@ text_config=Gemma3TextConfig( vocab_size=262_144, hidden_size=2560, - intermediate_size=10_240, + intermediate_size=2560 * 8 // 2, num_attention_heads=8, + head_dim=256, num_hidden_layers=34, num_key_value_heads=4, - attention_pattern=DEFAULT_ATTENION_PATTERN, sliding_window=1024, - rope_global_base_freq=1_000_000, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, query_pre_attn_scalar=256**-0.5, @@ -127,14 +152,15 @@ _VARIANT_GEMMA_3_12B: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, - hidden_size=3840, - intermediate_size=3840 * 8 // 2, + hidden_size=30 * 128, + intermediate_size=30 * 128 * 8 // 2, num_attention_heads=16, + head_dim=256, num_hidden_layers=48, num_key_value_heads=8, - attention_pattern=DEFAULT_ATTENION_PATTERN, sliding_window=1024, - rope_global_base_freq=1_000_000, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, query_pre_attn_scalar=256**-0.5, @@ -144,18 +170,18 @@ _VARIANT_GEMMA_3_27B: Gemma3Config( text_config=Gemma3TextConfig( vocab_size=262_144, - hidden_size=5376, - intermediate_size=5376 * 8 // 2, + hidden_size=42 * 128, + intermediate_size=42 * 128 * 8 // 2, num_attention_heads=32, num_hidden_layers=62, num_key_value_heads=16, head_dim=128, - attention_pattern=DEFAULT_ATTENION_PATTERN, sliding_window=1024, - rope_global_base_freq=1_000_000, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, - query_pre_attn_scalar=1 / math.sqrt(5376 // 32), # 1 / sqrt(hidden_size // num_attention_heads) + query_pre_attn_scalar=(42 * 128 // 32) ** -0.5, # 1 / sqrt(hidden_size // num_attention_heads) ), vision_config=_VISION_CONFIG, ), @@ -483,6 +509,7 @@ def main(*args): ) if INCLUDE_CHAT_TEMPLATE.value: + # Include chat temaplate for CausalLM models tokenizer.chat_template = _CHAT_TEMPLATE if _TEXT_ONLY.value: @@ -493,16 +520,19 @@ def main(*args): else: image_processor = Gemma3ImageProcessor( image_seq_length=256, - image_mean=(127.5,) * 3, - image_std=(127.5,) * 3, + image_mean=(0.5,) * 3, + image_std=(0.5,) * 3, size={"height": 896, "width": 896}, - do_rescale=False, resample=PILImageResampling.BILINEAR, ) processor = Gemma3Processor( image_processor=image_processor, tokenizer=tokenizer, ) + if INCLUDE_CHAT_TEMPLATE.value: + # Duplicate so multimodal instruct models can also be used for CausalLM + processor.chat_template = tokenizer.chat_template + processor.save_pretrained(output_path) logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) del processor diff --git a/hf-gemma3/image_processing_gemma3.py b/hf-gemma3/image_processing_gemma3.py index 2787057cf618..46334cd41fb2 100644 --- a/hf-gemma3/image_processing_gemma3.py +++ b/hf-gemma3/image_processing_gemma3.py @@ -60,7 +60,7 @@ class Gemma3ImageProcessor(BaseImageProcessor): `do_resize` in the `preprocess` method. size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in @@ -82,6 +82,12 @@ class Gemma3ImageProcessor(BaseImageProcessor): Whether to convert the image to RGB. do_pan_and_scan (`bool`, *optional*): Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. """ model_input_names = ["pixel_values", "num_crops"] @@ -91,7 +97,7 @@ def __init__( do_resize: bool = True, size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, - do_rescale: bool = False, + do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, @@ -105,6 +111,7 @@ def __init__( ) -> None: super().__init__(**kwargs) size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD @@ -132,17 +139,18 @@ def pan_and_scan( input_data_format: Optional[Union[str, ChannelDimension]] = None, ): """ - Pan and Scan and image, whatever it means. TODO: write-up docs + Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds + minumum allowed ratio. Args: image (`np.ndarray`): Image to resize. - pan_and_scan_min_crop_size (`int`): - Size of pan_and_scan_min_crop_size. - pan_and_scan_max_num_crops (`int`): - pan_and_scan_max_num_crops for the image. - pan_and_scan_min_ratio_to_activate (`int`): - pan_and_scan_min_ratio_to_activate for the image.. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): @@ -288,6 +296,12 @@ def preprocess( Whether to convert the image to RGB. do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`): + Minimum aspect ratio to activate pan and scan. """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size diff --git a/hf-gemma3/modeling_gemma3.py b/hf-gemma3/modeling_gemma3.py index 8ee518286ecc..5a79ec5143ff 100644 --- a/hf-gemma3/modeling_gemma3.py +++ b/hf-gemma3/modeling_gemma3.py @@ -19,9 +19,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy from collections.abc import Callable from dataclasses import dataclass -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,12 +31,11 @@ from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( - ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, is_torchdynamo_compiling, @@ -51,6 +51,59 @@ _CONFIG_FOR_DOC = "Gemma3Config" +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Gemma3 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3ScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + class Gemma3MLP(nn.Module): def __init__(self, config: Gemma3TextConfig): super().__init__() @@ -194,20 +247,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -ATTENTION_TYPE_GLOBAL = "global" -ATTENTION_TYPE_LOCAL = "local_sliding" -AttentionType = Literal["global", "local_sliding"] - - -def create_sliding_window_mask( - sliding_window_size: int, - q_pos: torch.Tensor, - kv_pos: torch.Tensor, -) -> torch.Tensor: - """Creates mask for sliding window attention.""" - return q_pos < kv_pos + sliding_window_size - - def eager_attention_forward( module: "Gemma3Attention", query: torch.Tensor, @@ -242,15 +281,14 @@ class Gemma3Attention(nn.Module): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.attention_dropout = config.attention_dropout - self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] self.config = config self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.is_causal = True self.layer_idx = layer_idx self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar - self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL - self.sliding_window = config.sliding_window if self.is_sliding else None + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.sliding_window = config.sliding_window self.q_proj = nn.Linear( config.hidden_size, @@ -295,10 +333,10 @@ def forward( query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) - if self.attention_type == ATTENTION_TYPE_GLOBAL: - cos, sin = position_embeddings_global - else: + if self.is_sliding: cos, sin = position_embeddings_local + else: + cos, sin = position_embeddings_global query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -317,18 +355,6 @@ def forward( seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - if self.is_sliding and key_states.shape[-2] > self.sliding_window: - assert self.sliding_window is not None - if query_states.shape[-2] == key_states.shape[-2]: - sliding_window_mask = create_sliding_window_mask( - sliding_window_size=self.sliding_window, - q_pos=torch.arange(query_states.shape[-2]).unsqueeze(-1), - kv_pos=torch.arange(key_states.shape[-2]).unsqueeze(-2), - ) - attention_mask = torch.logical_and(attention_mask, sliding_window_mask) - else: - raise ValueError() - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -386,9 +412,6 @@ def forward( last_cache_position: int = 0, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if not isinstance(past_key_value, HybridCache): - raise ValueError("Gemma 3 only supports a HybridCache, required for local vs global attention") - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # In prefill, we may be larger than sliding window effective_seq_len = max(cache_position.shape[0], self.sliding_window) @@ -468,9 +491,14 @@ def forward( ) class Gemma3PreTrainedModel(PreTrainedModel): config_class = Gemma3Config - base_model_prefix = "model" + base_model_prefix = "language_model" supports_gradient_checkpointing = True - _no_split_modules = ["Gemma3DecoderLayer"] + _no_split_modules = [ + "Gemma3DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -481,8 +509,15 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_attention_backend = True def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() @@ -586,7 +621,10 @@ def __init__(self, config: Gemma3TextConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3ScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) self.layers = nn.ModuleList( [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -596,6 +634,7 @@ def __init__(self, config: Gemma3TextConfig): # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config = copy.deepcopy(config) config.rope_theta = config.rope_local_base_freq config.rope_scaling = {"rope_type": "default"} self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) @@ -643,11 +682,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # normalized - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype) - inputs_embeds = inputs_embeds * normalizer if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape @@ -848,6 +882,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Gemma3TextConfig + base_model_prefix = "language_model" def __init__(self, config: Gemma3TextConfig): super().__init__(config) @@ -991,42 +1026,23 @@ def prepare_inputs_for_generation( ): # Overwritten: has a special cache type, `HybridCache` - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing # (retrieving the same value from `cache_position` later on would crash dynamo) model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if logits_to_keep is None: + _ = model_inputs.pop("logits_to_keep", None) if ( isinstance(past_key_values, HybridCache) @@ -1049,19 +1065,8 @@ def prepare_inputs_for_generation( cache_position=cache_position, batch_size=batch_size, ) + model_inputs["attention_mask"] = attention_mask - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) return model_inputs @@ -1083,10 +1088,12 @@ def __init__(self, config: Gemma3Config): self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): - b, _, l = vision_outputs.shape + batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -1095,50 +1102,10 @@ def forward(self, vision_outputs: torch.Tensor): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) -@dataclass -class Gemma3CausalLMOutputWithPast(ModelOutput): - """ - Base class for Gemma3causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - - @add_start_docstrings( """The GEMMA3 model which consists of a vision backbone and a language model.""", GEMMA3_START_DOCSTRING, @@ -1196,9 +1163,12 @@ def _update_causal_mask( # form and requires no inversion or slicing. return attention_mask + using_static_cache = isinstance(past_key_values, StaticCache) min_dtype = torch.finfo(self.dtype).min - batch_size, sequence_length = input_tensor.shape[:2] - if isinstance(past_key_values, (HybridCache, StaticCache)): + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): target_length = past_key_values.get_max_cache_shape() else: target_length = ( @@ -1207,39 +1177,43 @@ def _update_causal_mask( else cache_position[0] + sequence_length + 1 ) - causal_mask = torch.ones((sequence_length, target_length), dtype=torch.bool) - causal_mask = torch.tril(causal_mask) - causal_mask = causal_mask.to(self.device) - - attention_mask = attention_mask.unsqueeze(-2).to(self.device) - causal_mask = causal_mask.unsqueeze(0).repeat(attention_mask.shape[0], 1, 1) - combined_mask = attention_mask * causal_mask[:, :, :sequence_length] - - image_token_mask = input_tensor == self.config.image_token_index - image_token_mask.to(self.device) - # logger.warning("image_token_mask shape = %s", image_token_mask.shape) - padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0) - padded_mask = padded_mask.to(self.device) - # logger.warning("padded_mask shape = %s", padded_mask.shape) - boundary = padded_mask[:, 1:] > padded_mask[:, :-1] - boundary = boundary.to(self.device) - numbered_boundary = torch.cumsum(boundary, dim=-1) - numbered_boundary = numbered_boundary.to(self.device) - q_block_indices = image_token_mask * numbered_boundary - q_block_indices = q_block_indices.to(self.device) - kv_block_indices = q_block_indices - # logger.warning("q_block_indices/kv_block_indices shape = %s", q_block_indices.shape) - bidirectional_mask = torch.logical_and( - kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), - q_block_indices.unsqueeze(-1) > 0, + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) - bidirectional_mask.to(self.device) - attention_mask = torch.logical_or(combined_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)).to(self.device) - full_attention_mask = torch.zeros((batch_size, 1, sequence_length, target_length)).to(self.device, torch.bool) - full_attention_mask[:, :, :, :sequence_length] = attention_mask - attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(self.device) - return attention_mask + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask def get_image_features(self, pixel_values: torch.Tensor): """ @@ -1255,6 +1229,7 @@ def get_image_features(self, pixel_values: torch.Tensor): image_features = self.multi_modal_projector(vision_outputs) return image_features + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1315,11 +1290,6 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1328,16 +1298,16 @@ def forward( is_training = token_type_ids is not None and labels is not None - if inputs_embeds is None: + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_index llm_input_ids = input_ids.clone() llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(llm_input_ids) - # normalized - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype) - inputs_embeds = inputs_embeds * normalizer if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1352,10 +1322,16 @@ def forward( if pixel_values is not None: image_features = self.get_image_features(pixel_values) - special_image_mask = special_image_mask.unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] raise ValueError( f"Number of images does not match number of special image tokens in the input text. " f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " @@ -1364,8 +1340,16 @@ def forward( image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_ids, is_training + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) outputs = self.language_model( attention_mask=causal_mask, diff --git a/hf-gemma3/modular_gemma3.py b/hf-gemma3/modular_gemma3.py index 5a090c5d26ea..8c2653cab03b 100644 --- a/hf-gemma3/modular_gemma3.py +++ b/hf-gemma3/modular_gemma3.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Sequence -from typing import List, Literal, Optional, Tuple, Union, cast +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,19 +27,20 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, + ModelOutput, ) from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import ( - add_start_docstrings_to_model_forward, - is_torchdynamo_compiling, logging, - replace_return_docstrings, ) +from ..bart.modeling_bart import BartScaledWordEmbedding from ..gemma2.modeling_gemma2 import ( Gemma2ForCausalLM, Gemma2MLP, Gemma2Model, + Gemma2PreTrainedModel, Gemma2RMSNorm, Gemma2RotaryEmbedding, apply_rotary_pos_emb, @@ -54,44 +57,26 @@ GEMMA3_INPUTS_DOCSTRING = "" -ATTENTION_TYPE_GLOBAL = "global" -ATTENTION_TYPE_LOCAL = "local_sliding" -AttentionType = Literal["global", "local_sliding"] -AttentionPattern = Sequence[AttentionType] -DEFAULT_ATTENION_PATTERN = cast( - AttentionPattern, - ( - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_LOCAL, - ATTENTION_TYPE_GLOBAL, - ), -) - class Gemma3TextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3-7B. + defaults will yield a similar configuration to that of the Gemma3-4B. e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 256000): + vocab_size (`int`, *optional*, defaults to 262144): Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Gemma3Model`] - num_hidden_layers (`int`, *optional*, defaults to 26): - Number of hidden layers in the Transformer decoder. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. hidden_size (`int`, *optional*, defaults to 2304): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 9216): Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 8): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 4): @@ -104,26 +89,12 @@ class Gemma3TextConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to - `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` - activation function. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_local_base_freq (float, *optional*, defaults to `rope_theta`): - The base period of the RoPE embeddings for local attention. + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window + attention. This is the size of the sliding window. + query_pre_attn_scalar (`float`, *optional*): + The scaling factor used on the attention scores, not that + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings used for global attention. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -161,28 +132,47 @@ class Gemma3TextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - attention_pattern (Sequence[AttentionTypes], defaults to (5 * local, global)): - The attention pattern to apply + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to + `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` + activation function. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to None): - The scaling factor used on the attention scores, not that - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window - attention. This is the size of the sliding window. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh soft-capping - on the attention scorexs. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + final_logit_softcapping (`bool`, *optional*, defaults to `True`): + Whether to apply logit softcapping or nor + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + Scaling factor when applying tanh soft-capping on the attention scorexs. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): + The cache type to be used with `generate`. ```python - >>> from transformers import Gemma3Model, Gemma3Config - >>> # Initializing a Gemma3 gemma3-7b style configuration + >>> from transformers import Gemma3Model, Gemma3TextConfig + >>> # Initializing a Gemma3 gemma3-4b style configuration >>> configuration = Gemma3Config() - >>> # Initializing a model from the gemma3-7b style configuration + >>> # Initializing a model from the gemma3-4b style configuration >>> model = Gemma3Model(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -192,20 +182,19 @@ class Gemma3TextConfig(PretrainedConfig): def __init__( self, - # Config parameters found in all implementations, name differences noted - vocab_size: int = 262_144, # num_embed in FLAX - hidden_size: int = 2304, # embed_dim in FLAX - intermediate_size: int = 9216, # hidden_dim in FLAX - num_hidden_layers: int = 26, # num_layers in FLAX - num_attention_heads: int = 8, # num_heads in FLAX - num_key_value_heads: int = 4, # num_kv_heads in FLAX + vocab_size: int = 262_144, + hidden_size: int = 2304, + intermediate_size: int = 9216, + num_hidden_layers: int = 26, + num_attention_heads: int = 8, + num_key_value_heads: int = 4, head_dim: int = 256, - sliding_window: int = 4096, # sliding_window_size in FLAX + sliding_window: int = 4096, query_pre_attn_scalar: Optional[float] = None, - attention_pattern: AttentionPattern = DEFAULT_ATTENION_PATTERN, rope_theta: float = 1_000_000.0, - rope_scaling = None, + rope_scaling=None, rope_local_base_freq: float = 10_000.0, + sliding_window_pattern: int = 6, rms_norm_eps: float = 1e-6, hidden_activation: str = "gelu_pytorch_tanh", pad_token_id: int = 0, @@ -244,9 +233,8 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.rope_local_base_freq = rope_local_base_freq - self.attention_pattern = attention_pattern # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = len(self.attention_pattern) + self.sliding_window_pattern = sliding_window_pattern self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation @@ -259,6 +247,53 @@ def __init__( class Gemma3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an + Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + text_config (`Union[Gemma3TextConfig, dict]`, *optional*): + The config object of the text backbone. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + image_token_index (`int`, *optional*, defaults to 262_144): + The image token index to encode the image prompt. + boi_token_index (`int`, *optional*, defaults to 255_999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256_000): + The end-of-image token index to wrap the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a Gemma3 Text config + >>> text_config = Gemma3TextConfig() + + >>> # Initializing a Gemma3 gemma-3-4b style configuration + >>> configuration = Gemma3Config(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3TextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gemma3" sub_configs = { "text_config": Gemma3TextConfig, @@ -302,6 +337,50 @@ def __init__( super().__init__(**kwargs) +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Gemma3 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3ScaledWordEmbedding(BartScaledWordEmbedding): + pass + + class Gemma3MLP(Gemma2MLP): def __init__(self, config: Gemma3TextConfig): super().__init__(config) @@ -317,43 +396,6 @@ def __init__(self, config: Gemma3TextConfig, device=None): super().__init__(config) -def create_sliding_window_mask( - position_ids: torch.LongTensor, - cache_position: int, - cache_len: int, - sliding_window_size: int, -) -> torch.Tensor: - """Creates mask for sliding window attention.""" - total_tokens = cache_position + position_ids.shape[1] # cached + processing tokens - - def _reconstruct_rotated_cache_positions(): - cache_positions = torch.arange(cache_len) + total_tokens - cache_len - rotated_cache_positions = torch.zeros_like(cache_positions) - rotated_cache_positions[cache_positions % cache_len] = cache_positions - return rotated_cache_positions - - # Reconstruct position_ids for cached kv. - if total_tokens <= cache_len: - cache_positions = torch.arange(cache_len) - else: - cache_positions = _reconstruct_rotated_cache_positions() - - cache_positions = cache_positions.unsqueeze(0).unsqueeze(0).to(position_ids.device) # [1, 1, cache_len] - position_ids = position_ids.unsqueeze(-1) # [B, seq_len, 1] - sliding_mask = cache_positions > position_ids - sliding_window_size - sliding_mask *= cache_positions < position_ids + sliding_window_size - return sliding_mask.unsqueeze(1) - - -def create_sliding_window_mask( - sliding_window_size: int, - q_pos: torch.Tensor, - kv_pos: torch.Tensor, -) -> torch.Tensor: - """Creates mask for sliding window attention.""" - return q_pos < kv_pos + sliding_window_size - - def eager_attention_forward( module: "Gemma3Attention", query: torch.Tensor, @@ -388,15 +430,14 @@ class Gemma3Attention(nn.Module): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.attention_dropout = config.attention_dropout - self.attention_type: AttentionType = config.attention_pattern[layer_idx % len(config.attention_pattern)] self.config = config self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.is_causal = True self.layer_idx = layer_idx self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar - self.is_sliding = self.attention_type == ATTENTION_TYPE_LOCAL - self.sliding_window = config.sliding_window if self.is_sliding else None + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.sliding_window = config.sliding_window self.q_proj = nn.Linear( config.hidden_size, @@ -441,10 +482,10 @@ def forward( query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) - if self.attention_type == ATTENTION_TYPE_GLOBAL: - cos, sin = position_embeddings_global - else: + if self.is_sliding: cos, sin = position_embeddings_local + else: + cos, sin = position_embeddings_global query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -463,18 +504,6 @@ def forward( seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - if self.is_sliding and key_states.shape[-2] > self.sliding_window: - assert self.sliding_window is not None - if query_states.shape[-2] == key_states.shape[-2]: - sliding_window_mask = create_sliding_window_mask( - sliding_window_size=self.sliding_window, - q_pos=torch.arange(query_states.shape[-2]).unsqueeze(-1), - kv_pos=torch.arange(key_states.shape[-2]).unsqueeze(-2), - ) - attention_mask = torch.logical_and(attention_mask, sliding_window_mask) - else: - raise ValueError() - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -532,9 +561,6 @@ def forward( last_cache_position: int = 0, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if not isinstance(past_key_value, HybridCache): - raise ValueError("Gemma 3 only supports a HybridCache, required for local vs global attention") - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # In prefill, we may be larger than sliding window effective_seq_len = max(cache_position.shape[0], self.sliding_window) @@ -591,14 +617,51 @@ def forward( return outputs +GEMMA3_START_DOCSTRING = None + + +class Gemma3PreTrainedModel(Gemma2PreTrainedModel): + base_model_prefix = "language_model" + _no_split_modules = [ + "Gemma3DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + + def _init_weights(self, module): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class Gemma3Model(Gemma2Model): config_class = Gemma3TextConfig def __init__(self, config: Gemma3TextConfig): super().__init__(config) + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3ScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config = copy.deepcopy(config) config.rope_theta = config.rope_local_base_freq config.rope_scaling = {"rope_type": "default"} self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) @@ -636,11 +699,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # normalized - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype) - inputs_embeds = inputs_embeds * normalizer if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape @@ -745,6 +803,7 @@ def forward( class Gemma3ForCausalLM(Gemma2ForCausalLM): config_class = Gemma3TextConfig + base_model_prefix = "language_model" def __init__(self, config: Gemma3TextConfig): super().__init__(config) @@ -768,10 +827,12 @@ def __init__(self, config: Gemma3Config): self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): - b, _, l = vision_outputs.shape + batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape(b, l, self.patches_per_image, self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -780,11 +841,10 @@ def forward(self, vision_outputs: torch.Tensor): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - projected_vision_outputs = torch.einsum("btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight) + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) -# The only diff on forward is that Gemma3 relies on `input_ids` when creating causal mask, while Paligemma -# passes input embeds. Can be removed when we enable good token type ids + class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ @@ -800,167 +860,6 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: image_features = self.multi_modal_projector(vision_outputs) return image_features - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, - ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - if inputs_embeds is None: - special_image_mask = (input_ids == self.config.image_token_index) - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - # normalized - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype) - inputs_embeds = inputs_embeds * normalizer - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - special_image_mask = special_image_mask.unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_ids, is_training - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - **lm_kwargs, - ) - - logits = outputs.logits - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - def _update_causal_mask( self, attention_mask, @@ -980,9 +879,12 @@ def _update_causal_mask( # form and requires no inversion or slicing. return attention_mask + using_static_cache = isinstance(past_key_values, StaticCache) min_dtype = torch.finfo(self.dtype).min - batch_size, sequence_length = input_tensor.shape[:2] - if isinstance(past_key_values, (HybridCache, StaticCache)): + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): target_length = past_key_values.get_max_cache_shape() else: target_length = ( @@ -991,39 +893,43 @@ def _update_causal_mask( else cache_position[0] + sequence_length + 1 ) - causal_mask = torch.ones((sequence_length, target_length), dtype=torch.bool) - causal_mask = torch.tril(causal_mask) - causal_mask = causal_mask.to(self.device) - - attention_mask = attention_mask.unsqueeze(-2).to(self.device) - causal_mask = causal_mask.unsqueeze(0).repeat(attention_mask.shape[0], 1, 1) - combined_mask = attention_mask * causal_mask[:, :, :sequence_length] - - image_token_mask = input_tensor == self.config.image_token_index - image_token_mask.to(self.device) - # logger.warning("image_token_mask shape = %s", image_token_mask.shape) - padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0) - padded_mask = padded_mask.to(self.device) - # logger.warning("padded_mask shape = %s", padded_mask.shape) - boundary = padded_mask[:, 1:] > padded_mask[:, :-1] - boundary = boundary.to(self.device) - numbered_boundary = torch.cumsum(boundary, dim=-1) - numbered_boundary = numbered_boundary.to(self.device) - q_block_indices = image_token_mask * numbered_boundary - q_block_indices = q_block_indices.to(self.device) - kv_block_indices = q_block_indices - # logger.warning("q_block_indices/kv_block_indices shape = %s", q_block_indices.shape) - bidirectional_mask = torch.logical_and( - kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), - q_block_indices.unsqueeze(-1) > 0, + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) - bidirectional_mask.to(self.device) - attention_mask = torch.logical_or(combined_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)).to(self.device) - full_attention_mask = torch.zeros((batch_size, 1, sequence_length, target_length)).to(self.device, torch.bool) - full_attention_mask[:, :, :, :sequence_length] = attention_mask - attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(self.device) - return attention_mask + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask __all__ = [ diff --git a/hf-gemma3/processing_gemma3.py b/hf-gemma3/processing_gemma3.py index b45cbe73c7ce..eab9759d9631 100644 --- a/hf-gemma3/processing_gemma3.py +++ b/hf-gemma3/processing_gemma3.py @@ -16,6 +16,8 @@ import re from typing import List, Optional, Union +import numpy as np + from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, make_nested_list_of_images from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack @@ -32,6 +34,7 @@ class Gemma3ImagesKwargs(ImagesKwargs): class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Gemma3ImagesKwargs _defaults = { "text_kwargs": { "padding": False, @@ -47,7 +50,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): class Gemma3Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] + valid_kwargs = ["chat_template", "image_seq_length"] image_processor_class = "Gemma3ImageProcessor" tokenizer_class = "AutoTokenizer" @@ -56,13 +59,14 @@ def __init__( image_processor, tokenizer, chat_template=None, - num_mm_soft_tokens_per_image: int = 256, + image_seq_length: int = 256, **kwargs, ): - self.image_seq_length = getattr(image_processor, "image_seq_length") + self.image_seq_length = image_seq_length self.image_token_id = tokenizer.image_token_id - image_tokens_expanded = "".join([tokenizer.image_token] * num_mm_soft_tokens_per_image) - self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded }{tokenizer.eoi_token}\n\n" + self.boi_token = tokenizer.boi_token + image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length) + self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" super().__init__( image_processor=image_processor, @@ -100,7 +104,7 @@ def __call__( # Create empty text to be replaced with placeholders if not text: - text = [" ".join([""] * len(images)) for images in batched_images] + text = [" ".join([self.boi_token] * len(images)) for images in batched_images] if len(batched_images) != len(text): raise ValueError( @@ -109,8 +113,9 @@ def __call__( # Replace image tokens by the full expanded sequence batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) - for prompt, images, num_crops in zip(text, batched_images, batch_num_crops): - image_indexes = [m.start() for m in re.finditer("", prompt)] + text_with_crops = text + for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)): + image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)] if len(images) != len(image_indexes): raise ValueError( @@ -121,16 +126,25 @@ def __call__( for num, idx in reversed(list(zip(num_crops, image_indexes))): if num: formatted_image_text = ( - "Here is the original image and here are some crops to help you see better " - + " ".join([""] * num) + f"Here is the original image {self.boi_token} and here are some crops to help you see better " + + " ".join([self.boi_token] * num) ) - prompt = prompt[:idx] + formatted_image_text + prompt[idx + len("") :] + prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :] + text_with_crops[batch_idx] = prompt # Expand placeholder image tokens to the full image token sequence - text = [prompt.replace("", self.full_image_sequence) for prompt in text] + text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") - text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) - return BatchFeature(data={**text_input, **image_inputs}) + # Add token type ids manually, as tokenizer can't do arbitrary position token types + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs + text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma def batch_decode(self, *args, **kwargs): @@ -151,7 +165,7 @@ def decode(self, *args, **kwargs): @property # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) From 12b7e9dec3860b1424301d42ca9995db8499f45c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 01:47:36 -0700 Subject: [PATCH 09/52] update Signed-off-by: Woosuk Kwon --- examples/offline_inference/vision_language.py | 2 +- examples/offline_inference/vision_language_multi_image.py | 2 +- tests/models/registry.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index d53073bf22ed..cbb686c32c93 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -114,7 +114,7 @@ def run_fuyu(question: str, modality: str): def run_gemma3(question: str, modality: str): assert modality == "image" prompt = f" {question}" - model_name = "gg-hf-g/gemma-3-4b" + model_name = "gg-hf-g/gemma-3-4b-it" llm = LLM(model=model_name, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index c2de580f12b1..5df45c6eb629 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -78,7 +78,7 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: - model_name = "gg-hf-g/gemma-3-4b" + model_name = "gg-hf-g/gemma-3-4b-it" llm = LLM(model=model_name, max_model_len=8192, max_num_seqs=2, diff --git a/tests/models/registry.py b/tests/models/registry.py index fca5297bcbb6..e464db2746cc 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -241,7 +241,7 @@ def check_available_online( "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "Gemma3ForConditionalGeneration": _HfExamplesInfo("gg-hf-g/gemma-3-4b-it-pr"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("gg-hf-g/gemma-3-4b-it"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 From 26b6199dc1b593352c4ebd5ceb906ac4406a4a4b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 01:55:25 -0700 Subject: [PATCH 10/52] minor Signed-off-by: Woosuk Kwon --- .../multimodal/processing/test_common.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index faa8eb0cc2f8..0916f08e277f 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -153,7 +153,7 @@ def _test_processing_correctness( "deepseek-ai/deepseek-vl2-tiny", "microsoft/Florence-2-base", "adept/fuyu-8b", - "gg-hf-g/gemma-3-4b-it-pr", + "gg-hf-g/gemma-3-4b-it", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", @@ -195,28 +195,28 @@ def test_processing_correctness( ) -# yapf: disable -@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable -def test_processing_correctness_phi3v( - model_id: str, - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - # HACK - this is an attempted workaround for the following bug - # https://github.com/huggingface/transformers/issues/34307 - from transformers import AutoImageProcessor # noqa: F401 - from transformers import AutoProcessor # noqa: F401 - - AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) - - _test_processing_correctness( - model_id, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ) +# # yapf: disable +# @pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"]) +# @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +# @pytest.mark.parametrize("num_batches", [32]) +# @pytest.mark.parametrize("simplify_rate", [1.0]) +# # yapf: enable +# def test_processing_correctness_phi3v( +# model_id: str, +# hit_rate: float, +# num_batches: int, +# simplify_rate: float, +# ): +# # HACK - this is an attempted workaround for the following bug +# # https://github.com/huggingface/transformers/issues/34307 +# from transformers import AutoImageProcessor # noqa: F401 +# from transformers import AutoProcessor # noqa: F401 + +# AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) + +# _test_processing_correctness( +# model_id, +# hit_rate=hit_rate, +# num_batches=num_batches, +# simplify_rate=simplify_rate, +# ) From b99336a14b6d4e66648c90bd8d350bc87ecaae28 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 01:59:09 -0700 Subject: [PATCH 11/52] Remove Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 7f0e0d500d4b..5ac410f5fb65 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -50,9 +50,6 @@ logger = init_logger(__name__) -ATTENTION_TYPE_LOCAL = "local_sliding" -AttentionType = Literal["global", "local_sliding"] - class Gemma3MLP(nn.Module): From 004dc924965cf24a4ade341394c45290169dd2ba Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 02:18:37 -0700 Subject: [PATCH 12/52] fix Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 5ac410f5fb65..b8ca3aa71fa2 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -13,13 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Optional, Set, Tuple, Union, Literal +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn -# FIXME -# from transformers import Gemma3Config -from transformers.models.gemma3.configuration_gemma3 import Gemma3Config, Gemma3TextConfig +from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile From 2bb965bc99ecd144cf08b89525d4a06ada346e1f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 08:59:06 -0700 Subject: [PATCH 13/52] Add kwargs Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index b8ca3aa71fa2..16fa1d704506 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -173,6 +173,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -185,7 +186,7 @@ def forward( k = k.flatten(-2, -1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) + attn_output = self.attn(q, k, v, **kwargs) output, _ = self.o_proj(attn_output) return output @@ -234,6 +235,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -244,6 +246,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -297,6 +300,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -313,6 +317,7 @@ def forward( positions, hidden_states, residual, + **kwargs, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -422,9 +427,10 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + inputs_embeds, **kwargs) return hidden_states def compute_logits( From b11994552f6371d9f2d49a18620e42ee16ecdf1b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 14:23:11 -0700 Subject: [PATCH 14/52] Hew HF Signed-off-by: Woosuk Kwon --- hf-gemma3/__init__.py | 1 + hf-gemma3/configuration_gemma3.py | 178 ++++---- .../convert_gemma3_weights_orbax_to_hf.py | 47 ++- hf-gemma3/image_processing_gemma3.py | 68 ++- hf-gemma3/image_processing_gemma3_fast.py | 387 ++++++++++++++++++ hf-gemma3/modeling_gemma3.py | 84 ++-- hf-gemma3/modular_gemma3.py | 279 ++++--------- hf-gemma3/processing_gemma3.py | 3 +- 8 files changed, 684 insertions(+), 363 deletions(-) create mode 100644 hf-gemma3/image_processing_gemma3_fast.py diff --git a/hf-gemma3/__init__.py b/hf-gemma3/__init__.py index 511b0a38e1d8..37ec82f91037 100644 --- a/hf-gemma3/__init__.py +++ b/hf-gemma3/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .configuration_gemma3 import * from .image_processing_gemma3 import * + from .image_processing_gemma3_fast import * from .modeling_gemma3 import * from .processing_gemma3 import * else: diff --git a/hf-gemma3/configuration_gemma3.py b/hf-gemma3/configuration_gemma3.py index 276aecc5602c..c19a05ba60c4 100644 --- a/hf-gemma3/configuration_gemma3.py +++ b/hf-gemma3/configuration_gemma3.py @@ -32,17 +32,16 @@ class Gemma3TextConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 + This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3-4B. - e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + defaults will yield a similar configuration to that of the Gemma3Text-7B. + e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 262144): - Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Gemma3Model`] + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3TextModel`] hidden_size (`int`, *optional*, defaults to 2304): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 9216): @@ -61,14 +60,43 @@ class Gemma3TextConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window - attention. This is the size of the sliding window. - query_pre_attn_scalar (`float`, *optional*): - The scaling factor used on the attention scores, not that + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings used for global attention. + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: @@ -108,79 +136,68 @@ class Gemma3TextConfig(PretrainedConfig): The base period of the RoPE embeddings for local attention. sliding_window_pattern (`int`, *optional*, defaults to 6): Pattern for the sliding window attention. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to - `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` - activation function. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - max_position_embeddings (`int`, *optional*, defaults to 131072): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - final_logit_softcapping (`bool`, *optional*, defaults to `True`): - Whether to apply logit softcapping or nor - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): - Scaling factor when applying tanh soft-capping on the attention scorexs. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): - The cache type to be used with `generate`. ```python - >>> from transformers import Gemma3Model, Gemma3TextConfig - >>> # Initializing a Gemma3 gemma3-4b style configuration - >>> configuration = Gemma3Config() - >>> # Initializing a model from the gemma3-4b style configuration - >>> model = Gemma3Model(configuration) + >>> from transformers import Gemma3TextModel, Gemma3TextConfig + >>> # Initializing a Gemma3Text gemma3_text-7b style configuration + >>> configuration = Gemma3TextConfig() + >>> # Initializing a model from the gemma3_text-7b style configuration + >>> model = Gemma3TextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + """ model_type = "gemma3_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, - vocab_size: int = 262_144, - hidden_size: int = 2304, - intermediate_size: int = 9216, - num_hidden_layers: int = 26, - num_attention_heads: int = 8, - num_key_value_heads: int = 4, - head_dim: int = 256, - sliding_window: int = 4096, - query_pre_attn_scalar: Optional[float] = None, - rope_theta: float = 1_000_000.0, - rope_scaling=None, - rope_local_base_freq: float = 10_000.0, - sliding_window_pattern: int = 6, - rms_norm_eps: float = 1e-6, - hidden_activation: str = "gelu_pytorch_tanh", - pad_token_id: int = 0, - eos_token_id: int = 1, - bos_token_id: int = 2, - tie_word_embeddings: bool = True, - max_position_embeddings: int = 131_072, - initializer_range: float = 0.02, - attention_bias: bool = False, - attention_dropout: float = 0.0, - use_cache: bool = True, + vocab_size=262_208, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=131_072, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=1_000_000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, final_logit_softcapping=None, attn_logit_softcapping=None, - cache_implementation: str = "hybrid", + cache_implementation="hybrid", + rope_scaling=None, + rope_local_base_freq=10_000.0, + sliding_window_pattern=6, **kwargs, ): super().__init__( @@ -190,7 +207,6 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -203,10 +219,6 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.rope_local_base_freq = rope_local_base_freq - # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = sliding_window_pattern self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation @@ -215,6 +227,11 @@ def __init__( self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping self.cache_implementation = cache_implementation + + self.rope_local_base_freq = rope_local_base_freq + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = sliding_window_pattern + self.rope_scaling = rope_scaling rope_config_validation(self) @@ -245,6 +262,7 @@ class Gemma3Config(PretrainedConfig): initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Example: ```python diff --git a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py index f9f4c45d45b1..bb833fa7c3b2 100644 --- a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py +++ b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \ @@ -126,14 +142,14 @@ rope_theta=1_000_000, # used for global RoPE only rope_local_base_freq=10_000, attn_logit_softcapping=None, - query_pre_attn_scalar=256**-0.5, + query_pre_attn_scalar=256, max_position_embeddings=32_768, ), vision_config=None, ), _VARIANT_GEMMA_3_4B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_144, + vocab_size=262_208, hidden_size=2560, intermediate_size=2560 * 8 // 2, num_attention_heads=8, @@ -145,13 +161,13 @@ rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, - query_pre_attn_scalar=256**-0.5, + query_pre_attn_scalar=256, ), vision_config=_VISION_CONFIG, ), _VARIANT_GEMMA_3_12B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_144, + vocab_size=262_208, hidden_size=30 * 128, intermediate_size=30 * 128 * 8 // 2, num_attention_heads=16, @@ -163,13 +179,13 @@ rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, - query_pre_attn_scalar=256**-0.5, + query_pre_attn_scalar=256, ), vision_config=_VISION_CONFIG, ), _VARIANT_GEMMA_3_27B: Gemma3Config( text_config=Gemma3TextConfig( - vocab_size=262_144, + vocab_size=262_208, hidden_size=42 * 128, intermediate_size=42 * 128 * 8 // 2, num_attention_heads=32, @@ -181,7 +197,7 @@ rope_theta=1_000_000, rope_local_base_freq=10_000, attn_logit_softcapping=None, - query_pre_attn_scalar=(42 * 128 // 32) ** -0.5, # 1 / sqrt(hidden_size // num_attention_heads) + query_pre_attn_scalar=(42 * 128 // 32), # 1 / sqrt(hidden_size // num_attention_heads) ), vision_config=_VISION_CONFIG, ), @@ -348,6 +364,15 @@ def convert_transformer_weights( if prop == "input_embedding": # Tied to language_model.lm_head.weight, assigned at the end. converted_paths = ["language_model.model.embed_tokens.weight"] + + if not _TEXT_ONLY.value: + # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama + pre_expansion_embeddings = weights + mu = np.mean(pre_expansion_embeddings, axis=0) + sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) + new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) + weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + converted_weights = [weights] elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"): return zip([], []) @@ -434,11 +459,6 @@ def convert_transformer_weights( return zip(converted_paths, converted_weights) -def transpose_reshape(x: torch.Tensor) -> torch.Tensor: - x = x.transpose(1, 2) - return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous() - - @dataclasses.dataclass(frozen=True) class ConversionResult: state_tree: dict[str, torch.Tensor] @@ -509,8 +529,9 @@ def main(*args): ) if INCLUDE_CHAT_TEMPLATE.value: - # Include chat temaplate for CausalLM models + # Include chat template for CausalLM models tokenizer.chat_template = _CHAT_TEMPLATE + config.eos_token_id = [1, 106] if _TEXT_ONLY.value: config.vision_config = None diff --git a/hf-gemma3/image_processing_gemma3.py b/hf-gemma3/image_processing_gemma3.py index 46334cd41fb2..f985a9a9dd80 100644 --- a/hf-gemma3/image_processing_gemma3.py +++ b/hf-gemma3/image_processing_gemma3.py @@ -198,12 +198,20 @@ def pan_and_scan( crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] - return [ - image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] - for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) - ] + if input_data_format == ChannelDimension.LAST: + image_crops = [ + image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + else: + image_crops = [ + image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] - def _process_images_for_pas( + return image_crops + + def _process_images_for_pan_and_scan( self, images: List[np.ndarray], do_pan_and_scan: bool, @@ -362,7 +370,7 @@ def preprocess( if do_pan_and_scan: images_list_and_num_crops = [ - self._process_images_for_pas( + self._process_images_for_pan_and_scan( images=images, do_pan_and_scan=do_pan_and_scan, pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, @@ -378,41 +386,27 @@ def preprocess( else: num_crops = [[0] for images in images_list] - if do_resize: - height, width = size["height"], size["width"] - images_list = [ - [ - resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) - for image in images - ] - for images in images_list - ] + processed_images = [] + for images in images_list: + for image in images: + if do_resize: + height, width = size["height"], size["width"] + image = resize( + image=image, size=(height, width), resample=resample, input_data_format=input_data_format + ) - if do_rescale: - images_list = [ - [ - self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - for image in images - ] - for images in images_list - ] + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - if do_normalize: - images_list = [ - [ - self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) - for image in images - ] - for images in images_list - ] + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - for images in images_list - for image in images - ] + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) - data = {"pixel_values": images, "num_crops": num_crops} + data = {"pixel_values": processed_images, "num_crops": num_crops} return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/hf-gemma3/image_processing_gemma3_fast.py b/hf-gemma3/image_processing_gemma3_fast.py new file mode 100644 index 000000000000..fd4848ce21da --- /dev/null +++ b/hf-gemma3/image_processing_gemma3_fast.py @@ -0,0 +1,387 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SigLIP.""" + +import itertools +import math +from functools import partial +from typing import List, Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorInitKwargs, + DefaultFastImageProcessorPreprocessKwargs, + get_size_dict, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + SizeDict, + get_image_size, + make_nested_list_of_images, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) + + +if is_vision_available(): + from ...image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + +logger = logging.get_logger(__name__) + + +class Gemma3FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + + +class Gemma3FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + + +@add_start_docstrings( + "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of Pan adn Scan cropping method.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """, +) +class Gemma3ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = True + do_resize = True + do_rescale = True + do_normalize = True + do_pan_and_scan = None + pan_and_scan_min_crop_size = None + pan_and_scan_max_num_crops = None + pan_and_scan_min_ratio_to_activate = None + valid_init_kwargs = Gemma3FastImageProcessorInitKwargs + valid_preprocess_kwargs = Gemma3FastImageProcessorPreprocessKwargs + + def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorInitKwargs]): + super().__init__(**kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. + """ + return make_nested_list_of_images(images) + + def _prepare_input_images( + self, + images: ImageInput, + do_convert_rgb: bool = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + ) -> List["torch.Tensor"]: + """ + Prepare the input images for processing. + """ + batch_images = self._prepare_images_structure(images) + process_image_fn = partial( + self._process_image, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + # todo: yoni - check if we can parallelize this efficiently + batch_processed_images = [] + for image_list in batch_images: + processed_images = [] + for image in image_list: + processed_images.append(process_image_fn(image)) + batch_processed_images.append(processed_images) + + return batch_processed_images + + def pan_and_scan( + self, + image: "torch.Tensor", + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + ): + """ + Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds + minumum allowed ratio. + + Args: + image (`torch.Tensor`): + Image to resize. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """ + height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST) + + # Square or landscape image. + if width >= height: + # Only apply PaS if the image is sufficiently exaggerated + if width / height < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. + num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + + # Portrait image. + else: + # Only apply PaS if the image is sufficiently exaggerated + if height / width < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_h = int(math.floor(height / width + 0.5)) + num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(width / num_crops_w)) + crop_size_h = int(math.ceil(height / num_crops_h)) + + # Don't apply PaS if crop size is too small. + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return [] + + crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] + crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] + + return [ + image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + + def _process_images_for_pan_and_scan( + self, + images: List["torch.Tensor"], + do_pan_and_scan: bool, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + ): + pas_images_list = [] + num_crops = [] + for image in images: + pas_images = self.pan_and_scan( + image=image, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + ) + pas_images_list.extend([image] + pas_images) + num_crops.append(len(pas_images)) + return pas_images_list, num_crops + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """, + ) + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[Gemma3FastImageProcessorPreprocessKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys() + ) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_preprocess_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Pop kwargs that need further processing or won't be used in _preprocess + default_to_square = kwargs.pop("default_to_square") + size = kwargs.pop("size") + crop_size = kwargs.pop("crop_size") + image_mean = kwargs.pop("image_mean") + image_std = kwargs.pop("image_std") + data_format = kwargs.pop("data_format") + resample = kwargs.pop("resample") + + # Make hashable for cache + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None + crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + image_mean, image_std, interpolation = self._prepare_process_arguments( + size=size, + crop_size=crop_size, + resample=resample, + image_mean=image_mean, + image_std=image_std, + data_format=data_format if data_format is not None else ChannelDimension.FIRST, + device=images[0][0].device, + do_resize=kwargs.get("do_resize"), + do_center_crop=kwargs.get("do_center_crop"), + do_rescale=kwargs.get("do_rescale"), + rescale_factor=kwargs.get("rescale_factor"), + do_normalize=kwargs.get("do_normalize"), + return_tensors=kwargs.get("return_tensors"), + ) + + return self._preprocess( + images=images, + size=size, + crop_size=crop_size, + interpolation=interpolation, + image_mean=image_mean, + image_std=image_std, + **kwargs, + ) + + def _preprocess( + self, + images: List[List["torch.Tensor"]], + do_resize: bool, + size: SizeDict, + do_pan_and_scan: Optional[bool], + pan_and_scan_min_crop_size: Optional[int], + pan_and_scan_max_num_crops: Optional[int], + pan_and_scan_min_ratio_to_activate: Optional[float], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + processed_images = [] + batch_num_crops = [] + + for image_list in images: + if do_pan_and_scan: + images_list, num_crops = self._process_images_for_pan_and_scan( + images=image_list, + do_pan_and_scan=do_pan_and_scan, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + ) + else: + num_crops = [[0] for images in images_list] + + # Group images by size for batched processing + processed_image_patches_grouped = {} + grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list) + for shape, stacked_image_patches in grouped_image_patches.items(): + if do_resize: + stacked_image_patches = self.resize( + image=stacked_image_patches, + size=size, + interpolation=interpolation, + ) + # Fused rescale and normalize + stacked_image_patches = self.rescale_and_normalize( + stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_image_patches_grouped[shape] = stacked_image_patches + processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index) + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.extend(processed_image_patches) + batch_num_crops.extend(num_crops) + + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors + ) + + +__all__ = ["Gemma3ImageProcessorFast"] diff --git a/hf-gemma3/modeling_gemma3.py b/hf-gemma3/modeling_gemma3.py index 5a79ec5143ff..fc4e686fbbbe 100644 --- a/hf-gemma3/modeling_gemma3.py +++ b/hf-gemma3/modeling_gemma3.py @@ -91,7 +91,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None -class Gemma3ScaledWordEmbedding(nn.Embedding): +class Gemma3TextScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ @@ -248,15 +248,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def eager_attention_forward( - module: "Gemma3Attention", + module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, + softcap: Optional[float] = None, **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: if scaling is None: scaling = module.head_dim**-0.5 @@ -265,6 +266,10 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -278,51 +283,46 @@ def eager_attention_forward( class Gemma3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() - self.attention_dropout = config.attention_dropout + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) self.config = config - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.is_causal = True self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = config.query_pre_attn_scalar - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) - self.sliding_window = config.sliding_window + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = True self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.is_sliding else None + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - position_embeddings_global: torch.Tensor, - position_embeddings_local: torch.Tensor, + position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -333,11 +333,7 @@ def forward( query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) - if self.is_sliding: - cos, sin = position_embeddings_local - else: - cos, sin = position_embeddings_global - + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -438,11 +434,15 @@ def forward( hidden_states = self.input_layernorm(hidden_states) - # Self Attention + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -603,15 +603,15 @@ def _init_weights(self, module): @add_start_docstrings( - "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.", + "The bare Gemma3Text Model outputting raw hidden-states without any specific head on top.", GEMMA3_START_DOCSTRING, ) -class Gemma3Model(Gemma3PreTrainedModel): +class Gemma3TextModel(Gemma3PreTrainedModel): """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3DecoderLayer`] + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3TextDecoderLayer`] Args: - config: Gemma3Config + config: Gemma3TextConfig """ config_class = Gemma3TextConfig @@ -622,7 +622,7 @@ def __init__(self, config: Gemma3TextConfig): self.vocab_size = config.vocab_size # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 - self.embed_tokens = Gemma3ScaledWordEmbedding( + self.embed_tokens = Gemma3TextScaledWordEmbedding( config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 ) self.layers = nn.ModuleList( @@ -792,7 +792,7 @@ def _update_causal_mask( past_key_values: HybridCache, output_attentions: bool, ): - # Flash Attention currently doesn't support static cache but Gemma3 work only with static cache. + # Flash Attention currently doesn't support static cache but Gemma3Text work only with static cache. # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. @@ -886,7 +886,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): def __init__(self, config: Gemma3TextConfig): super().__init__(config) - self.model = Gemma3Model(config) + self.model = Gemma3TextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1154,9 +1154,7 @@ def _update_causal_mask( is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None + return attention_mask if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted @@ -1447,4 +1445,4 @@ def prepare_inputs_for_generation( return model_inputs -__all__ = ["Gemma3PreTrainedModel", "Gemma3Model", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] +__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] diff --git a/hf-gemma3/modular_gemma3.py b/hf-gemma3/modular_gemma3.py index 8c2653cab03b..fa3107ab5baa 100644 --- a/hf-gemma3/modular_gemma3.py +++ b/hf-gemma3/modular_gemma3.py @@ -36,7 +36,9 @@ logging, ) from ..bart.modeling_bart import BartScaledWordEmbedding +from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( + Gemma2Attention, Gemma2ForCausalLM, Gemma2MLP, Gemma2Model, @@ -44,7 +46,7 @@ Gemma2RMSNorm, Gemma2RotaryEmbedding, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration from ..siglip import SiglipVisionConfig @@ -58,19 +60,18 @@ GEMMA3_INPUTS_DOCSTRING = "" -class Gemma3TextConfig(PretrainedConfig): +class Gemma3TextConfig(Gemma2Config): r""" - This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3 + This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3-4B. - e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + defaults will yield a similar configuration to that of the Gemma3Text-7B. + e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 262144): - Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Gemma3Model`] + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3TextModel`] hidden_size (`int`, *optional*, defaults to 2304): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 9216): @@ -89,14 +90,43 @@ class Gemma3TextConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window - attention. This is the size of the sliding window. - query_pre_attn_scalar (`float`, *optional*): - The scaling factor used on the attention scores, not that + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings used for global attention. + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: @@ -136,113 +166,42 @@ class Gemma3TextConfig(PretrainedConfig): The base period of the RoPE embeddings for local attention. sliding_window_pattern (`int`, *optional*, defaults to 6): Pattern for the sliding window attention. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to - `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` - activation function. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - max_position_embeddings (`int`, *optional*, defaults to 131072): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - final_logit_softcapping (`bool`, *optional*, defaults to `True`): - Whether to apply logit softcapping or nor - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): - Scaling factor when applying tanh soft-capping on the attention scorexs. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): - The cache type to be used with `generate`. ```python - >>> from transformers import Gemma3Model, Gemma3TextConfig - >>> # Initializing a Gemma3 gemma3-4b style configuration - >>> configuration = Gemma3Config() - >>> # Initializing a model from the gemma3-4b style configuration - >>> model = Gemma3Model(configuration) + >>> from transformers import Gemma3TextModel, Gemma3TextConfig + >>> # Initializing a Gemma3Text gemma3_text-7b style configuration + >>> configuration = Gemma3TextConfig() + >>> # Initializing a model from the gemma3_text-7b style configuration + >>> model = Gemma3TextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + """ model_type = "gemma3_text" def __init__( self, - vocab_size: int = 262_144, - hidden_size: int = 2304, - intermediate_size: int = 9216, - num_hidden_layers: int = 26, - num_attention_heads: int = 8, - num_key_value_heads: int = 4, - head_dim: int = 256, - sliding_window: int = 4096, - query_pre_attn_scalar: Optional[float] = None, - rope_theta: float = 1_000_000.0, + vocab_size=262_208, + rope_theta=1_000_000.0, rope_scaling=None, - rope_local_base_freq: float = 10_000.0, - sliding_window_pattern: int = 6, - rms_norm_eps: float = 1e-6, - hidden_activation: str = "gelu_pytorch_tanh", - pad_token_id: int = 0, - eos_token_id: int = 1, - bos_token_id: int = 2, - tie_word_embeddings: bool = True, - max_position_embeddings: int = 131_072, - initializer_range: float = 0.02, - attention_bias: bool = False, - attention_dropout: float = 0.0, - use_cache: bool = True, + rope_local_base_freq=10_000.0, + sliding_window_pattern=6, + max_position_embeddings=131_072, final_logit_softcapping=None, attn_logit_softcapping=None, - cache_implementation: str = "hybrid", - **kwargs, + **super_kwargs, ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) + super().__init__(self, **super_kwargs) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.num_key_value_heads = num_key_value_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling self.rope_local_base_freq = rope_local_base_freq # For configuring HybridCache to work with 5:1 attention pattern self.sliding_window_pattern = sliding_window_pattern - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.hidden_activation = hidden_activation - self.query_pre_attn_scalar = query_pre_attn_scalar - self.sliding_window = sliding_window - self.final_logit_softcapping = final_logit_softcapping - self.attn_logit_softcapping = attn_logit_softcapping - self.cache_implementation = cache_implementation + self.rope_scaling = rope_scaling rope_config_validation(self) @@ -258,21 +217,22 @@ class Gemma3Config(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Args: - vision_config (`Union[AutoConfig, dict]`, *optional*): - Custom vision config or dict. text_config (`Union[Gemma3TextConfig, dict]`, *optional*): The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. mm_tokens_per_image (`int`, *optional*, defaults to 256): The number of tokens per image embedding. - image_token_index (`int`, *optional*, defaults to 262_144): - The image token index to encode the image prompt. - boi_token_index (`int`, *optional*, defaults to 255_999): + boi_token_index (`int`, *optional*, defaults to 255999): The begin-of-image token index to wrap the image prompt. - eoi_token_index (`int`, *optional*, defaults to 256_000): + eoi_token_index (`int`, *optional*, defaults to 256000): The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Example: ```python @@ -377,7 +337,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None -class Gemma3ScaledWordEmbedding(BartScaledWordEmbedding): +class Gemma3TextScaledWordEmbedding(BartScaledWordEmbedding): pass @@ -396,77 +356,21 @@ def __init__(self, config: Gemma3TextConfig, device=None): super().__init__(config) -def eager_attention_forward( - module: "Gemma3Attention", - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - if scaling is None: - scaling = module.head_dim**-0.5 - - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -class Gemma3Attention(nn.Module): +# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding` +class Gemma3Attention(Gemma2Attention): def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.attention_dropout = config.attention_dropout - self.config = config - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.is_causal = True - self.layer_idx = layer_idx - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = config.query_pre_attn_scalar self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) - self.sliding_window = config.sliding_window - self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, - ) + super().__init__() + self.sliding_window = config.sliding_window if self.is_sliding else None + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - position_embeddings_global: torch.Tensor, - position_embeddings_local: torch.Tensor, + position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, @@ -482,11 +386,7 @@ def forward( query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) - if self.is_sliding: - cos, sin = position_embeddings_local - else: - cos, sin = position_embeddings_global - + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -587,11 +487,15 @@ def forward( hidden_states = self.input_layernorm(hidden_states) - # Self Attention + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -648,14 +552,14 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -class Gemma3Model(Gemma2Model): +class Gemma3TextModel(Gemma2Model): config_class = Gemma3TextConfig def __init__(self, config: Gemma3TextConfig): super().__init__(config) # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 - self.embed_tokens = Gemma3ScaledWordEmbedding( + self.embed_tokens = Gemma3TextScaledWordEmbedding( config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 ) @@ -807,6 +711,7 @@ class Gemma3ForCausalLM(Gemma2ForCausalLM): def __init__(self, config: Gemma3TextConfig): super().__init__(config) + self.model = Gemma3TextModel(config) class Gemma3MultiModalProjector(nn.Module): @@ -870,9 +775,7 @@ def _update_causal_mask( is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None + return attention_mask if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted @@ -936,7 +839,7 @@ def _update_causal_mask( "Gemma3Config", "Gemma3TextConfig", "Gemma3PreTrainedModel", # noqa: F822 - "Gemma3Model", + "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", ] diff --git a/hf-gemma3/processing_gemma3.py b/hf-gemma3/processing_gemma3.py index eab9759d9631..e82b609bdb10 100644 --- a/hf-gemma3/processing_gemma3.py +++ b/hf-gemma3/processing_gemma3.py @@ -51,7 +51,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): class Gemma3Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] valid_kwargs = ["chat_template", "image_seq_length"] - image_processor_class = "Gemma3ImageProcessor" + image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__( @@ -163,7 +163,6 @@ def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs) @property - # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] image_processor_input_names = self.image_processor.model_input_names From 4c675733d19ea22c241f3b12a497a7377c547dec Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 14:23:20 -0700 Subject: [PATCH 15/52] Fix scaling Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 16fa1d704506..b4bcb2010044 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -114,7 +114,7 @@ def __init__(self, self.head_dim = head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = config.query_pre_attn_scalar + self.scaling = config.query_pre_attn_scalar**-0.5 self.qkv_proj = QKVParallelLinear( hidden_size, From d90d410ce31bec128333e70c7d50a139015421b0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 15:42:04 -0700 Subject: [PATCH 16/52] bidirectional attn Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 42 +++++++++++++++++++++- vllm/model_executor/models/gemma3_mm.py | 47 ++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index b4bcb2010044..9399d28ea6cd 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -17,6 +17,7 @@ import torch from torch import nn +import torch.nn.functional as F from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig from vllm.attention import Attention @@ -186,7 +187,46 @@ def forward( k = k.flatten(-2, -1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, **kwargs) + attn_output = self.attn(q, k, v) + + if not kwargs.get("has_images", False): + output, _ = self.o_proj(attn_output) + return output + + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. + num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + out = F.scaled_dot_product_attention( + query, key, value, attn_mask, self.scaling, + ) + + out = out.transpose(1, 2).flatten(-2, -1) + attn_output[start_idx:end_idx] = out + start_idx = end_idx + output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 14d98047cdaf..e409147cefd5 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -323,19 +323,64 @@ def forward(self, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: if intermediate_tensors is not None: inputs_embeds = None + kwargs.clear() # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) + kwargs.clear() + inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) + if vision_embeddings is not None: + kwargs["has_images"] = True + start_idices = (positions == 0).cpu().nonzero() + num_seqs = len(start_idices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_idices[i].item() + if i < num_seqs - 1: + end_idx = start_idices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + for seq_len in seq_lens: + input_token_ids = input_ids[start_idx:end_idx] + global_attn_mask = torch.empty( + 1, 1, seq_len, seq_len, + dtype=vision_embeddings.dtype, + device=vision_embeddings.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + local_attn_mask = global_attn_mask + local_attn_masks.append(local_attn_mask) + + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + input_ids = None hidden_states = self.language_model.model(input_ids, positions, intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + **kwargs) return hidden_states From 366e4b5fa4eea12ce7190ae1ee81168bf1c7df33 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 16:03:01 -0700 Subject: [PATCH 17/52] sliding window Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e409147cefd5..b5cd54283533 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -364,11 +364,13 @@ def forward(self, img_pos = (input_token_ids == self.config.image_token_index) img_mask[:, :, :, img_pos] += 1 img_mask[:, :, img_pos, :] += 1 - global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) - local_attn_mask = global_attn_mask + SLIDING_WINDOW_SIZE = 1024 + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-SLIDING_WINDOW_SIZE) + local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks From 77b9dd7bdc103eb3efe9809aef4bc7a994f1561b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 16:28:32 -0700 Subject: [PATCH 18/52] Remove HF Signed-off-by: Woosuk Kwon --- hf-gemma3/__init__.py | 30 - hf-gemma3/configuration_gemma3.py | 330 ---- .../convert_gemma3_weights_orbax_to_hf.py | 588 ------- hf-gemma3/image_processing_gemma3.py | 413 ----- hf-gemma3/image_processing_gemma3_fast.py | 387 ----- hf-gemma3/modeling_gemma3.py | 1448 ----------------- hf-gemma3/modular_gemma3.py | 845 ---------- hf-gemma3/processing_gemma3.py | 172 -- 8 files changed, 4213 deletions(-) delete mode 100644 hf-gemma3/__init__.py delete mode 100644 hf-gemma3/configuration_gemma3.py delete mode 100644 hf-gemma3/convert_gemma3_weights_orbax_to_hf.py delete mode 100644 hf-gemma3/image_processing_gemma3.py delete mode 100644 hf-gemma3/image_processing_gemma3_fast.py delete mode 100644 hf-gemma3/modeling_gemma3.py delete mode 100644 hf-gemma3/modular_gemma3.py delete mode 100644 hf-gemma3/processing_gemma3.py diff --git a/hf-gemma3/__init__.py b/hf-gemma3/__init__.py deleted file mode 100644 index 37ec82f91037..000000000000 --- a/hf-gemma3/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING - -from ...utils import _LazyModule -from ...utils.import_utils import define_import_structure - - -if TYPE_CHECKING: - from .configuration_gemma3 import * - from .image_processing_gemma3 import * - from .image_processing_gemma3_fast import * - from .modeling_gemma3 import * - from .processing_gemma3 import * -else: - import sys - - _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/hf-gemma3/configuration_gemma3.py b/hf-gemma3/configuration_gemma3.py deleted file mode 100644 index c19a05ba60c4..000000000000 --- a/hf-gemma3/configuration_gemma3.py +++ /dev/null @@ -1,330 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma3.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from ...configuration_utils import PretrainedConfig -from ...modeling_rope_utils import rope_config_validation -from ...utils import logging -from ..siglip import SiglipVisionConfig - - -logger = logging.get_logger(__name__) - - -class Gemma3TextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3Text-7B. - e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 262208): - Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Gemma3TextModel`] - hidden_size (`int`, *optional*, defaults to 2304): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 9216): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 26): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 8): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 4): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 131072): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): - Scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the - size of the sliding window. - final_logit_softcapping (`float`, *optional*): - Scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*): - Scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - rope_local_base_freq (float, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. - - ```python - >>> from transformers import Gemma3TextModel, Gemma3TextConfig - >>> # Initializing a Gemma3Text gemma3_text-7b style configuration - >>> configuration = Gemma3TextConfig() - >>> # Initializing a model from the gemma3_text-7b style configuration - >>> model = Gemma3TextModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - rope_local_base_freq (float, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. - """ - - model_type = "gemma3_text" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=262_208, - hidden_size=2304, - intermediate_size=9216, - num_hidden_layers=26, - num_attention_heads=8, - num_key_value_heads=4, - head_dim=256, - hidden_activation="gelu_pytorch_tanh", - max_position_embeddings=131_072, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - bos_token_id=2, - tie_word_embeddings=True, - rope_theta=1_000_000.0, - attention_bias=False, - attention_dropout=0.0, - query_pre_attn_scalar=256, - sliding_window=4096, - final_logit_softcapping=None, - attn_logit_softcapping=None, - cache_implementation="hybrid", - rope_scaling=None, - rope_local_base_freq=10_000.0, - sliding_window_pattern=6, - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.num_key_value_heads = num_key_value_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.hidden_activation = hidden_activation - self.query_pre_attn_scalar = query_pre_attn_scalar - self.sliding_window = sliding_window - self.final_logit_softcapping = final_logit_softcapping - self.attn_logit_softcapping = attn_logit_softcapping - self.cache_implementation = cache_implementation - - self.rope_local_base_freq = rope_local_base_freq - # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = sliding_window_pattern - self.rope_scaling = rope_scaling - rope_config_validation(self) - - -class Gemma3Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an - Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the PaliGemma-2B. - - e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - text_config (`Union[Gemma3TextConfig, dict]`, *optional*): - The config object of the text backbone. - vision_config (`Union[AutoConfig, dict]`, *optional*): - Custom vision config or dict. - mm_tokens_per_image (`int`, *optional*, defaults to 256): - The number of tokens per image embedding. - boi_token_index (`int`, *optional*, defaults to 255999): - The begin-of-image token index to wrap the image prompt. - eoi_token_index (`int`, *optional*, defaults to 256000): - The end-of-image token index to wrap the image prompt. - image_token_index (`int`, *optional*, defaults to 262144): - The image token index to encode the image prompt. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - - - Example: - - ```python - >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig - - >>> # Initializing a Siglip-like vision config - >>> vision_config = SiglipVisionConfig() - - >>> # Initializing a Gemma3 Text config - >>> text_config = Gemma3TextConfig() - - >>> # Initializing a Gemma3 gemma-3-4b style configuration - >>> configuration = Gemma3Config(vision_config, text_config) - - >>> # Initializing a model from the gemma-3-4b style configuration - >>> model = Gemma3TextConfig(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "gemma3" - sub_configs = { - "text_config": Gemma3TextConfig, - "vision_config": SiglipVisionConfig, - } - - def __init__( - self, - text_config: Optional[Gemma3TextConfig] = None, - vision_config: Optional[SiglipVisionConfig] = None, - mm_tokens_per_image: int = 256, - boi_token_index: int = 255_999, - eoi_token_index: int = 256_000, - image_token_index: int = 262_144, - initializer_range: float = 0.02, - **kwargs, - ): - if text_config is None: - text_config = Gemma3TextConfig() - logger.info("text_config is None, using default Gemma3TextConfig vision config.") - elif isinstance(text_config, dict): - text_config = Gemma3TextConfig(**text_config) - - if isinstance(vision_config, dict): - vision_config = SiglipVisionConfig(**vision_config) - else: - vision_config = SiglipVisionConfig() - logger.info( - "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited " - "to text tasks." - ) - - self.text_config = text_config - self.vision_config = vision_config - self.mm_tokens_per_image = mm_tokens_per_image - self.boi_token_index = boi_token_index - self.eoi_token_index = eoi_token_index - self.image_token_index = image_token_index - self.initializer_range = initializer_range - - super().__init__(**kwargs) - - -__all__ = ["Gemma3Config", "Gemma3TextConfig"] diff --git a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py b/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py deleted file mode 100644 index bb833fa7c3b2..000000000000 --- a/hf-gemma3/convert_gemma3_weights_orbax_to_hf.py +++ /dev/null @@ -1,588 +0,0 @@ -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. - -python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \ - --variant='gemma3_4b' \ - --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ - --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ - --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \ - --precision='bfloat16' -""" - -import dataclasses -from collections.abc import Iterator, Sequence -from typing import Any - -import accelerate -import numpy as np -import torch -import tree -from absl import app, flags, logging -from orbax import checkpoint as obc - -from ...image_utils import PILImageResampling -from ..gemma import GemmaTokenizerFast -from . import ( - Gemma3ForCausalLM, - Gemma3ForConditionalGeneration, - Gemma3ImageProcessor, - Gemma3Processor, -) -from .configuration_gemma3 import ( - Gemma3Config, - Gemma3TextConfig, - SiglipVisionConfig, -) - - -# ==== Internal Constants and Classes ==== - - -_CHAT_TEMPLATE = """{{ bos_token }} -{%- if messages[0]['role'] == 'system' -%} - {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} - {%- set loop_messages = messages[1:] -%} -{%- else -%} - {%- set first_user_prefix = "" -%} - {%- set loop_messages = messages -%} -{%- endif -%} -{%- for message in loop_messages -%} - {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} - {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} - {%- endif -%} - {%- if (message['role'] == 'assistant') -%} - {%- set role = "model" -%} - {%- else -%} - {%- set role = message['role'] -%} - {%- endif -%} - {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }} - {%- if message['content'] is string -%} - {{ message['content'] | trim }} - {%- elif message['content'] is iterable -%} - {%- for item in message['content'] -%} - {%- if item['type'] == 'image' -%} - {{ '' }} - {%- elif item['type'] == 'text' -%} - {{ item['text'] | trim }} - {%- endif -%} - {%- endfor -%} - {%- else -%} - {{ raise_exception("Invalid content type") }} - {%- endif -%} - {{ '\n' }} -{%- endfor -%} -{%- if add_generation_prompt -%} - {{'model\n'}} -{%- endif -%} -""" - -_DTYPES = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "float16": torch.float16, -} - -_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder" -_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding" -_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_" -_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) -_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" - -_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" -_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) -_TRANSFORMER_EMBEDDER = "transformer/embedder" -_TRANSFORMER_FINAL_NORM = "transformer/final_norm" -_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" -_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) - -_VISION_CONFIG = { - "hidden_size": 1152, - "intermediate_size": 4304, - "num_hidden_layers": 27, - "num_attention_heads": 16, - "num_channels": 3, - "image_size": 896, - "patch_size": 14, - "hidden_act": "gelu_pytorch_tanh", - "layer_norm_eps": 1e-6, - "attention_dropout": 0.0, - "vision_use_head": False, -} - -_VARIANT_GEMMA_3_1B = "gemma3_1b" -_VARIANT_GEMMA_3_4B = "gemma3_4b" -_VARIANT_GEMMA_3_12B = "gemma3_12b" -_VARIANT_GEMMA_3_27B = "gemma3_27b" -_VARIANTS = { - _VARIANT_GEMMA_3_1B: Gemma3Config( - text_config=Gemma3TextConfig( - vocab_size=262_144, - hidden_size=1152, - intermediate_size=6 * 1152, - num_attention_heads=4, - num_hidden_layers=26, - num_key_value_heads=1, - head_dim=256, - sliding_window=512, - rope_theta=1_000_000, # used for global RoPE only - rope_local_base_freq=10_000, - attn_logit_softcapping=None, - query_pre_attn_scalar=256, - max_position_embeddings=32_768, - ), - vision_config=None, - ), - _VARIANT_GEMMA_3_4B: Gemma3Config( - text_config=Gemma3TextConfig( - vocab_size=262_208, - hidden_size=2560, - intermediate_size=2560 * 8 // 2, - num_attention_heads=8, - head_dim=256, - num_hidden_layers=34, - num_key_value_heads=4, - sliding_window=1024, - rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only - rope_theta=1_000_000, - rope_local_base_freq=10_000, - attn_logit_softcapping=None, - query_pre_attn_scalar=256, - ), - vision_config=_VISION_CONFIG, - ), - _VARIANT_GEMMA_3_12B: Gemma3Config( - text_config=Gemma3TextConfig( - vocab_size=262_208, - hidden_size=30 * 128, - intermediate_size=30 * 128 * 8 // 2, - num_attention_heads=16, - head_dim=256, - num_hidden_layers=48, - num_key_value_heads=8, - sliding_window=1024, - rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only - rope_theta=1_000_000, - rope_local_base_freq=10_000, - attn_logit_softcapping=None, - query_pre_attn_scalar=256, - ), - vision_config=_VISION_CONFIG, - ), - _VARIANT_GEMMA_3_27B: Gemma3Config( - text_config=Gemma3TextConfig( - vocab_size=262_208, - hidden_size=42 * 128, - intermediate_size=42 * 128 * 8 // 2, - num_attention_heads=32, - num_hidden_layers=62, - num_key_value_heads=16, - head_dim=128, - sliding_window=1024, - rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only - rope_theta=1_000_000, - rope_local_base_freq=10_000, - attn_logit_softcapping=None, - query_pre_attn_scalar=(42 * 128 // 32), # 1 / sqrt(hidden_size // num_attention_heads) - ), - vision_config=_VISION_CONFIG, - ), -} - -# ==== Flags ==== - -CHECKPOINT_PATH = flags.DEFINE_string( - name="checkpoint_path", - default=None, - help="Path to the Orbax checkpoint.", - required=True, -) - -INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool( - name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" -) - -OUTPUT_PATH = flags.DEFINE_string( - name="output_path", - default=None, - help="Path to store the HF checkpoint.", - required=True, -) - -PRECISION = flags.DEFINE_enum( - name="precision", - default=None, - help="The floating point precision (aka dtype) of the model.", - enum_values=set(_DTYPES.keys()), - required=True, -) - -_TEXT_ONLY = flags.DEFINE_bool( - name="text_only", - default=False, - help=( - "If True, the model is loaded and saved as a Gemma3ForCausalLM, " - "otherwise model saed as Gemma3ForConditionalGeneration." - ), -) - -TOKENIZER_PATH = flags.DEFINE_string( - name="tokenizer_path", - default=None, - help="Path to the SentencePiece model file.", - required=True, -) - -_VARIANT = flags.DEFINE_enum( - name="variant", - default=_VARIANT_GEMMA_3_4B, - help="The model variant to convert.", - enum_values=set(_VARIANTS.keys()), -) - - -def convert_siglip_weight( - config: SiglipVisionConfig, - paths: Sequence[str], - weights: np.ndarray, -) -> tuple[str, np.ndarray]: - path, prop = paths - normalized_path: str = "" - updated_weights: np.ndarray = None - - if path == _SIGLIP_BASE: - normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight" - updated_weights = weights.reshape(-1, config.hidden_size) - elif path == _SIGLIP_EMBEDDING: - if prop == "kernel": - normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight" - updated_weights = weights.transpose(3, 2, 0, 1) - elif prop == "bias": - normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias" - updated_weights = weights - else: - raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") - elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK): - encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:] - next_path_seperator_idx = encoder_block_path.find("/") - layer_idx = encoder_block_path[:next_path_seperator_idx] - encoder_block_path = encoder_block_path[next_path_seperator_idx:] - normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}" - - if encoder_block_path.startswith("/LayerNorm"): - normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2" - - if prop == "scale": - normalized_path += ".weight" - updated_weights = weights.transpose() - elif prop == "bias": - normalized_path += ".bias" - updated_weights = weights - else: - raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") - elif encoder_block_path.startswith("/MlpBlock_0"): - normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2" - - if prop == "kernel": - normalized_path += ".weight" - updated_weights = weights.transpose() - elif prop == "bias": - normalized_path += ".bias" - updated_weights = weights - else: - raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") - elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"): - if encoder_block_path.endswith("/key"): - normalized_path += ".self_attn.k_proj" - elif encoder_block_path.endswith("/out"): - normalized_path += ".self_attn.out_proj" - elif encoder_block_path.endswith("/query"): - normalized_path += ".self_attn.q_proj" - elif encoder_block_path.endswith("/value"): - normalized_path += ".self_attn.v_proj" - else: - raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.") - - if prop == "bias": - normalized_path += ".bias" - updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1) - elif prop == "kernel": - normalized_path += ".weight" - updated_weights = weights.reshape(-1, config.hidden_size).transpose() - else: - raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") - else: - raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.") - elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM: - if prop == "scale": - normalized_path = "vision_tower.vision_model.post_layernorm.weight" - updated_weights = weights.transpose() - elif prop == "bias": - normalized_path = "vision_tower.vision_model.post_layernorm.bias" - updated_weights = weights - else: - raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") - else: - raise ValueError(f"Unexpected path `{path}`.") - - if "vision" in normalized_path: - print(normalized_path) - return normalized_path, updated_weights - - -def convert_transformer_weights( - config: Gemma3TextConfig, - paths: Sequence[str], - weights: np.ndarray, -) -> Iterator[tuple[str, np.ndarray]]: - path, prop = paths - - if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): - path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] - - converted_paths: list[str] = [] - converted_weights: list[Any] = [] - - attn_head_dim = config.num_attention_heads * config.head_dim - kv_head_dim = config.num_key_value_heads * config.head_dim - - if path == _TRANSFORMER_EMBEDDER: - if prop == "input_embedding": - # Tied to language_model.lm_head.weight, assigned at the end. - converted_paths = ["language_model.model.embed_tokens.weight"] - - if not _TEXT_ONLY.value: - # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama - pre_expansion_embeddings = weights - mu = np.mean(pre_expansion_embeddings, axis=0) - sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) - new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) - weights = np.vstack([pre_expansion_embeddings, new_embeddings]) - - converted_weights = [weights] - elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"): - return zip([], []) - else: - raise ValueError(f"Unexpected member, {prop}, in Embedder.") - elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): - if _TEXT_ONLY.value: - return zip([], []) - - if path.endswith("/mm_input_projection"): - converted_paths = ["multi_modal_projector.mm_input_projection_weight"] - converted_weights = [weights] - elif path.endswith("/mm_soft_embedding_norm"): - converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] - converted_weights = [weights] - else: - raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") - elif path == _TRANSFORMER_FINAL_NORM: - converted_paths = ["language_model.model.norm.weight"] - converted_weights = [weights] - elif path.startswith(_TRANSFORMER_DECODER_BLOCK): - decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] - next_path_seperator_idx = decoder_block_path.find("/") - layer_idx = decoder_block_path[:next_path_seperator_idx] - decoder_block_path = decoder_block_path[next_path_seperator_idx:] - - base_path = f"language_model.model.layers.{layer_idx}" - - if path.endswith("attn/attn_vec_einsum"): - converted_paths = [f"{base_path}.self_attn.o_proj.weight"] - converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)] - elif path.endswith("attn/_key_norm"): - converted_paths = [f"{base_path}.self_attn.k_norm.weight"] - converted_weights = [weights] - elif path.endswith("attn/kv_einsum"): - converted_paths = [ - f"{base_path}.self_attn.k_proj.weight", - f"{base_path}.self_attn.v_proj.weight", - ] - k_proj_weights, v_proj_weights = weights - converted_weights = [ - k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), - v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), - ] - elif path.endswith("attn/q_einsum"): - converted_paths = [f"{base_path}.self_attn.q_proj.weight"] - converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)] - elif path.endswith("attn/_query_norm"): - converted_paths = [f"{base_path}.self_attn.q_norm.weight"] - converted_weights = [weights] - elif path.endswith("mlp/gating_einsum"): - converted_paths = [ - f"{base_path}.mlp.gate_proj.weight", - f"{base_path}.mlp.up_proj.weight", - ] - gate_proj_weight, up_proj_weight = weights - converted_weights = [gate_proj_weight, up_proj_weight] - elif path.endswith("mlp/linear"): - converted_paths = [f"{base_path}.mlp.down_proj.weight"] - converted_weights = [weights.transpose()] - elif path.endswith("post_attention_norm"): - converted_paths = [f"{base_path}.post_attention_layernorm.weight"] - converted_weights = [weights] - elif path.endswith("post_ffw_norm"): - converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"] - converted_weights = [weights] - elif path.endswith("pre_attention_norm"): - converted_paths = [f"{base_path}.input_layernorm.weight"] - converted_weights = [weights] - elif path.endswith("pre_ffw_norm"): - converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"] - converted_weights = [weights] - else: - raise ValueError(f"Unexpected path `{path}` in Decoder Block.") - else: - raise ValueError(f"Unexpected path `{path}`.") - - if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): - raise ValueError( - "The `converted_paths` and `converted_weights` should be the same " - f"length. Got {cpl} and {cwl}, respectively, for {path}." - ) - - return zip(converted_paths, converted_weights) - - -@dataclasses.dataclass(frozen=True) -class ConversionResult: - state_tree: dict[str, torch.Tensor] - config: Gemma3Config - - -def convert( - checkpoint_path: str, - config: Gemma3Config, - target_dtype: torch.dtype, -) -> ConversionResult: - """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" - checkpointer = obc.PyTreeCheckpointer() - ckpt = checkpointer.restore(checkpoint_path) - hf_tree: dict[str, torch.Tensor] = {} - - def update_tree(path: str, weights: np.ndarray) -> None: - torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype) - logging.info( - "%s converted shape=%s with dtype=%s", - path, - weights.shape, - torch_tensor.dtype, - ) - hf_tree[path] = torch_tensor - - for paths, value in tree.flatten_with_path(ckpt): - if paths[0].startswith("SigLiPFromPatches_"): - if config.vision_config is None: - continue - - path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) - update_tree(path, weights) - else: - for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): - if config.vision_config is None: - path = path[len("language_model.") :] - - update_tree(path, weights) - - if config.vision_config is None: - hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] - else: - hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"] - - return ConversionResult(state_tree=hf_tree, config=config) - - -def main(*args): - del args - - variant = _VARIANT.value - dtype = getattr(torch, PRECISION.value) - config = _VARIANTS[variant] - output_path = OUTPUT_PATH.value - - if variant == _VARIANT_GEMMA_3_1B: - flags.FLAGS.set_default(_TEXT_ONLY.name, True) - - tokenizer = GemmaTokenizerFast( - TOKENIZER_PATH.value, - add_bos_token=True, - extra_special_tokens={ - "image_token": "", # Should be ID=262_144 - "boi_token": "", # Should be ID=255_999 - "eoi_token": "", # Should be ID=256_000 - }, - ) - - if INCLUDE_CHAT_TEMPLATE.value: - # Include chat template for CausalLM models - tokenizer.chat_template = _CHAT_TEMPLATE - config.eos_token_id = [1, 106] - - if _TEXT_ONLY.value: - config.vision_config = None - tokenizer.save_pretrained(output_path) - logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) - del tokenizer - else: - image_processor = Gemma3ImageProcessor( - image_seq_length=256, - image_mean=(0.5,) * 3, - image_std=(0.5,) * 3, - size={"height": 896, "width": 896}, - resample=PILImageResampling.BILINEAR, - ) - processor = Gemma3Processor( - image_processor=image_processor, - tokenizer=tokenizer, - ) - if INCLUDE_CHAT_TEMPLATE.value: - # Duplicate so multimodal instruct models can also be used for CausalLM - processor.chat_template = tokenizer.chat_template - - processor.save_pretrained(output_path) - logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) - del processor - del tokenizer - - logging.info("Gemma 3 (%s) configured as: %s", variant, config) - logging.info("Converting Gemma 3 (%s) @ %s", variant, dtype) - result = convert(CHECKPOINT_PATH.value, config, dtype) - logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant) - - with accelerate.init_empty_weights(): - if config.vision_config is None: - model = Gemma3ForCausalLM(config=config.text_config) - else: - model = Gemma3ForConditionalGeneration(config) - - model.load_state_dict(result.state_tree, assign=True, strict=True) - model.config.torch_dtype = dtype - logging.info("Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.", variant, type(model).__name__) - model.save_pretrained(output_path, safe_serialization=True) - logging.info( - "Saved Gemma 3 (%s) to SafeTensors in %s using %s", - variant, - output_path, - type(model).__name__, - ) - del model - del result - - -if __name__ == "__main__": - app.run(main) diff --git a/hf-gemma3/image_processing_gemma3.py b/hf-gemma3/image_processing_gemma3.py deleted file mode 100644 index f985a9a9dd80..000000000000 --- a/hf-gemma3/image_processing_gemma3.py +++ /dev/null @@ -1,413 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Image processor class for Gemma3.""" - -import itertools -import math -from typing import Dict, List, Optional, Union - -import numpy as np - -from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict -from ...image_transforms import ( - convert_to_rgb, - resize, - to_channel_dimension_format, -) -from ...image_utils import ( - IMAGENET_STANDARD_MEAN, - IMAGENET_STANDARD_STD, - ChannelDimension, - ImageInput, - PILImageResampling, - get_image_size, - infer_channel_dimension_format, - is_scaled_image, - make_nested_list_of_images, - to_numpy_array, - valid_images, - validate_preprocess_arguments, -) -from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging - - -logger = logging.get_logger(__name__) - - -if is_vision_available(): - import PIL - - -class Gemma3ImageProcessor(BaseImageProcessor): - r""" - Constructs a SigLIP image processor. - - Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by - `do_resize` in the `preprocess` method. - size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): - Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): - Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. - do_rescale (`bool`, *optional*, defaults to `True`): - Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in - the `preprocess` method. - rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): - Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` - method. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether to normalize the image by the specified mean and standard deviation. Can be overridden by - `do_normalize` in the `preprocess` method. - image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): - Mean to use if normalizing the image. This is a float or list of floats the length of the number of - channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. - image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): - Standard deviation to use if normalizing the image. This is a float or list of floats the length of the - number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. - Can be overridden by the `image_std` parameter in the `preprocess` method. - do_convert_rgb (`bool`, *optional*, defaults to `True`): - Whether to convert the image to RGB. - do_pan_and_scan (`bool`, *optional*): - Whether to apply `pan_and_scan` to images. - pan_and_scan_min_crop_size (`int`, *optional*): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*): - Minimum aspect ratio to activate pan and scan. - """ - - model_input_names = ["pixel_values", "num_crops"] - - def __init__( - self, - do_resize: bool = True, - size: Dict[str, int] = None, - resample: PILImageResampling = PILImageResampling.BILINEAR, - do_rescale: bool = True, - rescale_factor: Union[int, float] = 1 / 255, - do_normalize: bool = True, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - do_pan_and_scan: bool = None, - pan_and_scan_min_crop_size: int = None, - pan_and_scan_max_num_crops: int = None, - pan_and_scan_min_ratio_to_activate: float = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - size = size if size is not None else {"height": 224, "width": 224} - size = get_size_dict(size, default_to_square=True) - image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN - image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD - - self.do_resize = do_resize - self.size = size - self.resample = resample - self.do_rescale = do_rescale - self.rescale_factor = rescale_factor - self.do_normalize = do_normalize - self.image_mean = image_mean - self.image_std = image_std - self.do_convert_rgb = do_convert_rgb - self.do_pan_and_scan = do_pan_and_scan - self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size - self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops - self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate - - def pan_and_scan( - self, - image: np.ndarray, - pan_and_scan_min_crop_size: int, - pan_and_scan_max_num_crops: int, - pan_and_scan_min_ratio_to_activate: float, - data_format: Optional[Union[str, ChannelDimension]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - """ - Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds - minumum allowed ratio. - - Args: - image (`np.ndarray`): - Image to resize. - pan_and_scan_min_crop_size (`int`, *optional*): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*): - Minimum aspect ratio to activate pan and scan. - data_format (`str` or `ChannelDimension`, *optional*): - The channel dimension format of the image. If not provided, it will be the same as the input image. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format of the input image. If not provided, it will be inferred. - """ - height, width = get_image_size(image) - - # Square or landscape image. - if width >= height: - # Only apply PaS if the image is sufficiently exaggerated - if width / height < pan_and_scan_min_ratio_to_activate: - return [] - - # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. - num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. - num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) - - # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. - num_crops_w = max(2, num_crops_w) - num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) - num_crops_h = 1 - - # Portrait image. - else: - # Only apply PaS if the image is sufficiently exaggerated - if height / width < pan_and_scan_min_ratio_to_activate: - return [] - - # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. - num_crops_h = int(math.floor(height / width + 0.5)) - num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) - - # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. - num_crops_h = max(2, num_crops_h) - num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) - num_crops_w = 1 - - crop_size_w = int(math.ceil(width / num_crops_w)) - crop_size_h = int(math.ceil(height / num_crops_h)) - - # Don't apply PaS if crop size is too small. - if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: - return [] - - crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] - crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] - - if input_data_format == ChannelDimension.LAST: - image_crops = [ - image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] - for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) - ] - else: - image_crops = [ - image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] - for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) - ] - - return image_crops - - def _process_images_for_pan_and_scan( - self, - images: List[np.ndarray], - do_pan_and_scan: bool, - pan_and_scan_min_crop_size: int, - pan_and_scan_max_num_crops: int, - pan_and_scan_min_ratio_to_activate: float, - data_format: Optional[Union[str, ChannelDimension]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - pas_images_list = [] - num_crops = [] - for image in images: - pas_images = self.pan_and_scan( - image=image, - pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, - pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, - pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, - data_format=data_format, - input_data_format=input_data_format, - ) - pas_images_list.extend([image] + pas_images) - num_crops.append(len(pas_images)) - return pas_images_list, num_crops - - @filter_out_non_signature_kwargs() - def preprocess( - self, - images: ImageInput, - do_resize: bool = None, - size: Dict[str, int] = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - do_convert_rgb: bool = None, - do_pan_and_scan: bool = None, - pan_and_scan_min_crop_size: int = None, - pan_and_scan_max_num_crops: int = None, - pan_and_scan_min_ratio_to_activate: float = None, - ) -> PIL.Image.Image: - """ - Preprocess an image or batch of images. - - Args: - images (`ImageInput`): - Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If - passing in images with pixel values between 0 and 1, set `do_rescale=False`. - do_resize (`bool`, *optional*, defaults to `self.do_resize`): - Whether to resize the image. - size (`Dict[str, int]`, *optional*, defaults to `self.size`): - Size of the image after resizing. - resample (`int`, *optional*, defaults to `self.resample`): - Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only - has an effect if `do_resize` is set to `True`. - do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): - Whether to rescale the image. - rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): - Rescale factor to rescale the image by if `do_rescale` is set to `True`. - do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to - `True`. - return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be one of: - - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): - The channel dimension format for the output image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - Unset: Use the channel dimension format of the input image. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): - Whether to convert the image to RGB. - do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`): - Whether to apply `pan_and_scan` to images. - pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`): - Minimum aspect ratio to activate pan and scan. - """ - do_resize = do_resize if do_resize is not None else self.do_resize - size = size if size is not None else self.size - size = get_size_dict(size, param_name="size", default_to_square=False) - resample = resample if resample is not None else self.resample - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - do_pan_and_scan = do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan - pan_and_scan_min_crop_size = ( - pan_and_scan_min_crop_size if pan_and_scan_min_crop_size is not None else self.pan_and_scan_min_crop_size - ) - pan_and_scan_max_num_crops = ( - pan_and_scan_max_num_crops if pan_and_scan_max_num_crops is not None else self.pan_and_scan_max_num_crops - ) - pan_and_scan_min_ratio_to_activate = ( - pan_and_scan_min_ratio_to_activate - if pan_and_scan_min_ratio_to_activate is not None - else self.pan_and_scan_min_ratio_to_activate - ) - - images_list = make_nested_list_of_images(images) - - if not valid_images(images_list[0]): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - - validate_preprocess_arguments( - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - do_resize=do_resize, - size=size, - resample=resample, - ) - if do_convert_rgb: - images_list = [[convert_to_rgb(image) for image in images] for images in images_list] - - # All transformations expect numpy arrays. - images_list = [[to_numpy_array(image) for image in images] for images in images_list] - - if do_rescale and is_scaled_image(images_list[0][0]): - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." - ) - - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images_list[0][0]) - - if do_pan_and_scan: - images_list_and_num_crops = [ - self._process_images_for_pan_and_scan( - images=images, - do_pan_and_scan=do_pan_and_scan, - pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, - pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, - pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, - data_format=data_format, - input_data_format=input_data_format, - ) - for images in images_list - ] - images_list = [images for images, _ in images_list_and_num_crops] - num_crops = [num_crops for _, num_crops in images_list_and_num_crops] - else: - num_crops = [[0] for images in images_list] - - processed_images = [] - for images in images_list: - for image in images: - if do_resize: - height, width = size["height"], size["width"] - image = resize( - image=image, size=(height, width), resample=resample, input_data_format=input_data_format - ) - - if do_rescale: - image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - - if do_normalize: - image = self.normalize( - image=image, mean=image_mean, std=image_std, input_data_format=input_data_format - ) - - image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - processed_images.append(image) - - data = {"pixel_values": processed_images, "num_crops": num_crops} - return BatchFeature(data=data, tensor_type=return_tensors) - - -__all__ = ["Gemma3ImageProcessor"] diff --git a/hf-gemma3/image_processing_gemma3_fast.py b/hf-gemma3/image_processing_gemma3_fast.py deleted file mode 100644 index fd4848ce21da..000000000000 --- a/hf-gemma3/image_processing_gemma3_fast.py +++ /dev/null @@ -1,387 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Fast Image processor class for SigLIP.""" - -import itertools -import math -from functools import partial -from typing import List, Optional, Union - -from ...image_processing_utils_fast import ( - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, - BaseImageProcessorFast, - BatchFeature, - DefaultFastImageProcessorInitKwargs, - DefaultFastImageProcessorPreprocessKwargs, - get_size_dict, - group_images_by_shape, - reorder_images, -) -from ...image_utils import ( - IMAGENET_STANDARD_MEAN, - IMAGENET_STANDARD_STD, - ChannelDimension, - ImageInput, - SizeDict, - get_image_size, - make_nested_list_of_images, - validate_kwargs, -) -from ...processing_utils import Unpack -from ...utils import ( - TensorType, - add_start_docstrings, - is_torch_available, - is_torchvision_available, - is_torchvision_v2_available, - is_vision_available, - logging, -) - - -if is_vision_available(): - from ...image_utils import PILImageResampling - -if is_torch_available(): - import torch - -if is_torchvision_available(): - if is_torchvision_v2_available(): - from torchvision.transforms.v2 import functional as F - else: - from torchvision.transforms import functional as F - -logger = logging.get_logger(__name__) - - -class Gemma3FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs): - do_pan_and_scan: Optional[bool] - pan_and_scan_min_crop_size: Optional[int] - pan_and_scan_max_num_crops: Optional[int] - pan_and_scan_min_ratio_to_activate: Optional[float] - - -class Gemma3FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs): - do_pan_and_scan: Optional[bool] - pan_and_scan_min_crop_size: Optional[int] - pan_and_scan_max_num_crops: Optional[int] - pan_and_scan_min_ratio_to_activate: Optional[float] - - -@add_start_docstrings( - "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of Pan adn Scan cropping method.", - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, - """ - do_pan_and_scan (`bool`, *optional*): - Whether to apply `pan_and_scan` to images. - pan_and_scan_min_crop_size (`int`, *optional*): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*): - Minimum aspect ratio to activate pan and scan. - """, -) -class Gemma3ImageProcessorFast(BaseImageProcessorFast): - resample = PILImageResampling.BILINEAR - image_mean = IMAGENET_STANDARD_MEAN - image_std = IMAGENET_STANDARD_STD - size = {"height": 224, "width": 224} - default_to_square = True - do_resize = True - do_rescale = True - do_normalize = True - do_pan_and_scan = None - pan_and_scan_min_crop_size = None - pan_and_scan_max_num_crops = None - pan_and_scan_min_ratio_to_activate = None - valid_init_kwargs = Gemma3FastImageProcessorInitKwargs - valid_preprocess_kwargs = Gemma3FastImageProcessorPreprocessKwargs - - def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorInitKwargs]): - super().__init__(**kwargs) - - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: - """ - Prepare the images structure for processing. - - Args: - images (`ImageInput`): - The input images to process. - - Returns: - `ImageInput`: The images with a valid nesting. - """ - return make_nested_list_of_images(images) - - def _prepare_input_images( - self, - images: ImageInput, - do_convert_rgb: bool = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - device: Optional["torch.device"] = None, - ) -> List["torch.Tensor"]: - """ - Prepare the input images for processing. - """ - batch_images = self._prepare_images_structure(images) - process_image_fn = partial( - self._process_image, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - device=device, - ) - # todo: yoni - check if we can parallelize this efficiently - batch_processed_images = [] - for image_list in batch_images: - processed_images = [] - for image in image_list: - processed_images.append(process_image_fn(image)) - batch_processed_images.append(processed_images) - - return batch_processed_images - - def pan_and_scan( - self, - image: "torch.Tensor", - pan_and_scan_min_crop_size: int, - pan_and_scan_max_num_crops: int, - pan_and_scan_min_ratio_to_activate: float, - ): - """ - Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds - minumum allowed ratio. - - Args: - image (`torch.Tensor`): - Image to resize. - pan_and_scan_min_crop_size (`int`, *optional*): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*): - Minimum aspect ratio to activate pan and scan. - """ - height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST) - - # Square or landscape image. - if width >= height: - # Only apply PaS if the image is sufficiently exaggerated - if width / height < pan_and_scan_min_ratio_to_activate: - return [] - - # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. - num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. - num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) - - # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. - num_crops_w = max(2, num_crops_w) - num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) - num_crops_h = 1 - - # Portrait image. - else: - # Only apply PaS if the image is sufficiently exaggerated - if height / width < pan_and_scan_min_ratio_to_activate: - return [] - - # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. - num_crops_h = int(math.floor(height / width + 0.5)) - num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) - - # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. - num_crops_h = max(2, num_crops_h) - num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) - num_crops_w = 1 - - crop_size_w = int(math.ceil(width / num_crops_w)) - crop_size_h = int(math.ceil(height / num_crops_h)) - - # Don't apply PaS if crop size is too small. - if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: - return [] - - crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] - crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] - - return [ - image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] - for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) - ] - - def _process_images_for_pan_and_scan( - self, - images: List["torch.Tensor"], - do_pan_and_scan: bool, - pan_and_scan_min_crop_size: int, - pan_and_scan_max_num_crops: int, - pan_and_scan_min_ratio_to_activate: float, - ): - pas_images_list = [] - num_crops = [] - for image in images: - pas_images = self.pan_and_scan( - image=image, - pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, - pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, - pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, - ) - pas_images_list.extend([image] + pas_images) - num_crops.append(len(pas_images)) - return pas_images_list, num_crops - - @add_start_docstrings( - BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, - """ - do_pan_and_scan (`bool`, *optional*): - Whether to apply `pan_and_scan` to images. - pan_and_scan_min_crop_size (`int`, *optional*): - Minimum size of each crop in pan and scan. - pan_and_scan_max_num_crops (`int`, *optional*): - Maximum number of crops per image in pan and scan. - pan_and_scan_min_ratio_to_activate (`float`, *optional*): - Minimum aspect ratio to activate pan and scan. - """, - ) - def preprocess( - self, - images: ImageInput, - **kwargs: Unpack[Gemma3FastImageProcessorPreprocessKwargs], - ) -> BatchFeature: - validate_kwargs( - captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys() - ) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. - for kwarg_name in self.valid_preprocess_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Pop kwargs that need further processing or won't be used in _preprocess - default_to_square = kwargs.pop("default_to_square") - size = kwargs.pop("size") - crop_size = kwargs.pop("crop_size") - image_mean = kwargs.pop("image_mean") - image_std = kwargs.pop("image_std") - data_format = kwargs.pop("data_format") - resample = kwargs.pop("resample") - - # Make hashable for cache - size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None - crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None - image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean - image_std = tuple(image_std) if isinstance(image_std, list) else image_std - - image_mean, image_std, interpolation = self._prepare_process_arguments( - size=size, - crop_size=crop_size, - resample=resample, - image_mean=image_mean, - image_std=image_std, - data_format=data_format if data_format is not None else ChannelDimension.FIRST, - device=images[0][0].device, - do_resize=kwargs.get("do_resize"), - do_center_crop=kwargs.get("do_center_crop"), - do_rescale=kwargs.get("do_rescale"), - rescale_factor=kwargs.get("rescale_factor"), - do_normalize=kwargs.get("do_normalize"), - return_tensors=kwargs.get("return_tensors"), - ) - - return self._preprocess( - images=images, - size=size, - crop_size=crop_size, - interpolation=interpolation, - image_mean=image_mean, - image_std=image_std, - **kwargs, - ) - - def _preprocess( - self, - images: List[List["torch.Tensor"]], - do_resize: bool, - size: SizeDict, - do_pan_and_scan: Optional[bool], - pan_and_scan_min_crop_size: Optional[int], - pan_and_scan_max_num_crops: Optional[int], - pan_and_scan_min_ratio_to_activate: Optional[float], - interpolation: Optional["F.InterpolationMode"], - do_center_crop: bool, - crop_size: SizeDict, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: Optional[Union[float, List[float]]], - image_std: Optional[Union[float, List[float]]], - return_tensors: Optional[Union[str, TensorType]], - ) -> BatchFeature: - processed_images = [] - batch_num_crops = [] - - for image_list in images: - if do_pan_and_scan: - images_list, num_crops = self._process_images_for_pan_and_scan( - images=image_list, - do_pan_and_scan=do_pan_and_scan, - pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, - pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, - pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, - ) - else: - num_crops = [[0] for images in images_list] - - # Group images by size for batched processing - processed_image_patches_grouped = {} - grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list) - for shape, stacked_image_patches in grouped_image_patches.items(): - if do_resize: - stacked_image_patches = self.resize( - image=stacked_image_patches, - size=size, - interpolation=interpolation, - ) - # Fused rescale and normalize - stacked_image_patches = self.rescale_and_normalize( - stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std - ) - processed_image_patches_grouped[shape] = stacked_image_patches - processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index) - processed_image_patches = ( - torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches - ) - processed_images.extend(processed_image_patches) - batch_num_crops.extend(num_crops) - - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature( - data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors - ) - - -__all__ = ["Gemma3ImageProcessorFast"] diff --git a/hf-gemma3/modeling_gemma3.py b/hf-gemma3/modeling_gemma3.py deleted file mode 100644 index fc4e686fbbbe..000000000000 --- a/hf-gemma3/modeling_gemma3.py +++ /dev/null @@ -1,1448 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma3.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -from collections.abc import Callable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_torchdynamo_compiling, - logging, - replace_return_docstrings, -) -from ...utils.deprecation import deprecate_kwarg -from ..auto import AutoModel, AutoModelForCausalLM -from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig - - -logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "Gemma3Config" - - -@dataclass -class Gemma3CausalLMOutputWithPast(ModelOutput): - """ - Base class for Gemma3 causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - - -class Gemma3TextScaledWordEmbedding(nn.Embedding): - """ - This module overrides nn.Embeddings' forward by multiplying with embeddings scale. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): - super().__init__(num_embeddings, embedding_dim, padding_idx) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor): - return super().forward(input_ids) * self.embed_scale - - -class Gemma3MLP(nn.Module): - def __init__(self, config: Gemma3TextConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_activation] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class Gemma3RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = output * (1.0 + self.weight.float()) - return output.type_as(x) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" - - -class Gemma3RotaryEmbedding(nn.Module): - def __init__(self, config: Gemma3TextConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - softcap: Optional[float] = None, - **kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: - if scaling is None: - scaling = module.head_dim**-0.5 - - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - - if softcap is not None: - attn_weights = attn_weights / softcap - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * softcap - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -class Gemma3Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = config.query_pre_attn_scalar**-0.5 - self.attention_dropout = self.config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.attn_logit_softcapping = self.config.attn_logit_softcapping - self.sliding_window = config.sliding_window if self.is_sliding else None - - self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: torch.Tensor, - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask.to(query_states), - dropout=self.attention_dropout if self.training else 0.0, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Gemma3DecoderLayer(nn.Module): - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma3MLP(config) - self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_sliding = self.self_attn.is_sliding - self.sliding_window = config.sliding_window - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings_global: torch.Tensor, - position_embeddings_local: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, - **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # apply global RoPE to non-sliding layer only - if self.self_attn.is_sliding: - position_embeddings = position_embeddings_local - else: - position_embeddings = position_embeddings_global - - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -GEMMA3_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Gemma3Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.", - GEMMA3_START_DOCSTRING, -) -class Gemma3PreTrainedModel(PreTrainedModel): - config_class = Gemma3Config - base_model_prefix = "language_model" - supports_gradient_checkpointing = True - _no_split_modules = [ - "Gemma3DecoderLayer", - "SiglipVisionEmbeddings", - "SiglipEncoderLayer", - "SiglipMultiheadAttentionPoolingHead", - ] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -GEMMA3_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Gemma3Text Model outputting raw hidden-states without any specific head on top.", - GEMMA3_START_DOCSTRING, -) -class Gemma3TextModel(Gemma3PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3TextDecoderLayer`] - - Args: - config: Gemma3TextConfig - """ - - config_class = Gemma3TextConfig - - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 - self.embed_tokens = Gemma3TextScaledWordEmbedding( - config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 - ) - self.layers = nn.ModuleList( - [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Gemma3RotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas - # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE - config = copy.deepcopy(config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default"} - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) - - # embed positions - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings_global = self.rotary_emb(hidden_states, position_ids) - position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - last_cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - last_cache_position=last_cache_position, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool, - ): - # Flash Attention currently doesn't support static cache but Gemma3Text work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. - if self.config._attn_implementation == "flash_attention_2": - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, (HybridCache, StaticCache)): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = Gemma3TextConfig - base_model_prefix = "language_model" - - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - self.model = Gemma3TextModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Gemma3ForCausalLM - - >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - - if self.training and self.config._attn_implementation != "eager": - logger.warning_once( - "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " - f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **loss_kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 - if logits_to_keep is None: - _ = model_inputs.pop("logits_to_keep", None) - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - model_inputs["attention_mask"] = attention_mask - - return model_inputs - - -class Gemma3MultiModalProjector(nn.Module): - def __init__(self, config: Gemma3Config): - super().__init__() - - self.mm_input_projection_weight = nn.Parameter( - torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) - ) - - self.mm_soft_emb_norm = Gemma3RMSNorm( - config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps - ) - - self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) - self.tokens_per_side = int(config.mm_tokens_per_image**0.5) - self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) - - def forward(self, vision_outputs: torch.Tensor): - batch_size, _, seq_length = vision_outputs.shape - - reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, self.patches_per_image - ) - reshaped_vision_outputs = reshaped_vision_outputs.contiguous() - - pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) - pooled_vision_outputs = pooled_vision_outputs.flatten(2) - pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) - - normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - - projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) - return projected_vision_outputs.type_as(vision_outputs) - - -@add_start_docstrings( - """The GEMMA3 model which consists of a vision backbone and a language model.""", - GEMMA3_START_DOCSTRING, -) -class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): - def __init__(self, config: Gemma3Config): - super().__init__(config) - self.vision_tower = AutoModel.from_config(config=config.vision_config) - self.multi_modal_projector = Gemma3MultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - - language_model = AutoModelForCausalLM.from_config(config=config.text_config) - - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - return attention_mask - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted - # form and requires no inversion or slicing. - return attention_mask - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - # Apply bidirectional mask on images if token type ids are provided - if token_type_ids is not None and sequence_length != 1: - token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) - token_type_mask[token_type_ids == 0] = False # if text token do not change anything - token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) - causal_mask = causal_mask.clone() - causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( - token_type_mask, 0.0 - ) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def get_image_features(self, pixel_values: torch.Tensor): - """ - Projects the last hidden state from the vision model into language model space. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). - """ - vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state - image_features = self.multi_modal_projector(vision_outputs) - return image_features - - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, - ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - **lm_kwargs, - ) - - logits = outputs.logits - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - logits_to_keep=None, - labels=None, - **kwargs, - ): - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # position_ids in Gemma3 are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - -__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] diff --git a/hf-gemma3/modular_gemma3.py b/hf-gemma3/modular_gemma3.py deleted file mode 100644 index fa3107ab5baa..000000000000 --- a/hf-gemma3/modular_gemma3.py +++ /dev/null @@ -1,845 +0,0 @@ -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -from collections.abc import Callable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from ...cache_utils import Cache, HybridCache, StaticCache -from ...configuration_utils import PretrainedConfig -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - ModelOutput, -) -from ...modeling_rope_utils import rope_config_validation -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...processing_utils import Unpack -from ...utils import ( - logging, -) -from ..bart.modeling_bart import BartScaledWordEmbedding -from ..gemma2.configuration_gemma2 import Gemma2Config -from ..gemma2.modeling_gemma2 import ( - Gemma2Attention, - Gemma2ForCausalLM, - Gemma2MLP, - Gemma2Model, - Gemma2PreTrainedModel, - Gemma2RMSNorm, - Gemma2RotaryEmbedding, - apply_rotary_pos_emb, - eager_attention_forward, -) -from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration -from ..siglip import SiglipVisionConfig - - -_CHECKPOINT_FOR_DOC = "google/gemma-3-4b" -_CONFIG_FOR_DOC = "Gemma3Config" - -logger = logging.get_logger(__name__) - -GEMMA3_INPUTS_DOCSTRING = "" - - -class Gemma3TextConfig(Gemma2Config): - r""" - This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma3Text-7B. - e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 262208): - Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Gemma3TextModel`] - hidden_size (`int`, *optional*, defaults to 2304): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 9216): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 26): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 8): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 4): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 131072): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): - Scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the - size of the sliding window. - final_logit_softcapping (`float`, *optional*): - Scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*): - Scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - rope_local_base_freq (float, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. - - ```python - >>> from transformers import Gemma3TextModel, Gemma3TextConfig - >>> # Initializing a Gemma3Text gemma3_text-7b style configuration - >>> configuration = Gemma3TextConfig() - >>> # Initializing a model from the gemma3_text-7b style configuration - >>> model = Gemma3TextModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - rope_local_base_freq (float, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. - """ - - model_type = "gemma3_text" - - def __init__( - self, - vocab_size=262_208, - rope_theta=1_000_000.0, - rope_scaling=None, - rope_local_base_freq=10_000.0, - sliding_window_pattern=6, - max_position_embeddings=131_072, - final_logit_softcapping=None, - attn_logit_softcapping=None, - **super_kwargs, - ): - super().__init__(self, **super_kwargs) - - self.rope_local_base_freq = rope_local_base_freq - # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = sliding_window_pattern - self.rope_scaling = rope_scaling - rope_config_validation(self) - - -class Gemma3Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an - Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the PaliGemma-2B. - - e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - text_config (`Union[Gemma3TextConfig, dict]`, *optional*): - The config object of the text backbone. - vision_config (`Union[AutoConfig, dict]`, *optional*): - Custom vision config or dict. - mm_tokens_per_image (`int`, *optional*, defaults to 256): - The number of tokens per image embedding. - boi_token_index (`int`, *optional*, defaults to 255999): - The begin-of-image token index to wrap the image prompt. - eoi_token_index (`int`, *optional*, defaults to 256000): - The end-of-image token index to wrap the image prompt. - image_token_index (`int`, *optional*, defaults to 262144): - The image token index to encode the image prompt. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - - - Example: - - ```python - >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig - - >>> # Initializing a Siglip-like vision config - >>> vision_config = SiglipVisionConfig() - - >>> # Initializing a Gemma3 Text config - >>> text_config = Gemma3TextConfig() - - >>> # Initializing a Gemma3 gemma-3-4b style configuration - >>> configuration = Gemma3Config(vision_config, text_config) - - >>> # Initializing a model from the gemma-3-4b style configuration - >>> model = Gemma3TextConfig(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "gemma3" - sub_configs = { - "text_config": Gemma3TextConfig, - "vision_config": SiglipVisionConfig, - } - - def __init__( - self, - text_config: Optional[Gemma3TextConfig] = None, - vision_config: Optional[SiglipVisionConfig] = None, - mm_tokens_per_image: int = 256, - boi_token_index: int = 255_999, - eoi_token_index: int = 256_000, - image_token_index: int = 262_144, - initializer_range: float = 0.02, - **kwargs, - ): - if text_config is None: - text_config = Gemma3TextConfig() - logger.info("text_config is None, using default Gemma3TextConfig vision config.") - elif isinstance(text_config, dict): - text_config = Gemma3TextConfig(**text_config) - - if isinstance(vision_config, dict): - vision_config = SiglipVisionConfig(**vision_config) - else: - vision_config = SiglipVisionConfig() - logger.info( - "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited " - "to text tasks." - ) - - self.text_config = text_config - self.vision_config = vision_config - self.mm_tokens_per_image = mm_tokens_per_image - self.boi_token_index = boi_token_index - self.eoi_token_index = eoi_token_index - self.image_token_index = image_token_index - self.initializer_range = initializer_range - - super().__init__(**kwargs) - - -@dataclass -class Gemma3CausalLMOutputWithPast(ModelOutput): - """ - Base class for Gemma3 causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - - -class Gemma3TextScaledWordEmbedding(BartScaledWordEmbedding): - pass - - -class Gemma3MLP(Gemma2MLP): - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - - -class Gemma3RMSNorm(Gemma2RMSNorm): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - - -class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): - def __init__(self, config: Gemma3TextConfig, device=None): - super().__init__(config) - - -# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding` -class Gemma3Attention(Gemma2Attention): - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) - - super().__init__() - self.sliding_window = config.sliding_window if self.is_sliding else None - - self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: torch.Tensor, - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask.to(query_states), - dropout=self.attention_dropout if self.training else 0.0, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Gemma3DecoderLayer(nn.Module): - def __init__(self, config: Gemma3TextConfig, layer_idx: int): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma3MLP(config) - self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_sliding = self.self_attn.is_sliding - self.sliding_window = config.sliding_window - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings_global: torch.Tensor, - position_embeddings_local: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, - **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # apply global RoPE to non-sliding layer only - if self.self_attn.is_sliding: - position_embeddings = position_embeddings_local - else: - position_embeddings = position_embeddings_global - - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -GEMMA3_START_DOCSTRING = None - - -class Gemma3PreTrainedModel(Gemma2PreTrainedModel): - base_model_prefix = "language_model" - _no_split_modules = [ - "Gemma3DecoderLayer", - "SiglipVisionEmbeddings", - "SiglipEncoderLayer", - "SiglipMultiheadAttentionPoolingHead", - ] - - def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class Gemma3TextModel(Gemma2Model): - config_class = Gemma3TextConfig - - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - - # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 - self.embed_tokens = Gemma3TextScaledWordEmbedding( - config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 - ) - - # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas - # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE - config = copy.deepcopy(config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default"} - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) - - # embed positions - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings_global = self.rotary_emb(hidden_states, position_ids) - position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - last_cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - last_cache_position=last_cache_position, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - -class Gemma3ForCausalLM(Gemma2ForCausalLM): - config_class = Gemma3TextConfig - base_model_prefix = "language_model" - - def __init__(self, config: Gemma3TextConfig): - super().__init__(config) - self.model = Gemma3TextModel(config) - - -class Gemma3MultiModalProjector(nn.Module): - def __init__(self, config: Gemma3Config): - super().__init__() - - self.mm_input_projection_weight = nn.Parameter( - torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) - ) - - self.mm_soft_emb_norm = Gemma3RMSNorm( - config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps - ) - - self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) - self.tokens_per_side = int(config.mm_tokens_per_image**0.5) - self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) - - def forward(self, vision_outputs: torch.Tensor): - batch_size, _, seq_length = vision_outputs.shape - - reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, self.patches_per_image - ) - reshaped_vision_outputs = reshaped_vision_outputs.contiguous() - - pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) - pooled_vision_outputs = pooled_vision_outputs.flatten(2) - pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) - - normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - - projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) - return projected_vision_outputs.type_as(vision_outputs) - - -class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): - def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: - """ - Projects the last hidden state from the vision model into language model space. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). - """ - vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state - image_features = self.multi_modal_projector(vision_outputs) - return image_features - - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - return attention_mask - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted - # form and requires no inversion or slicing. - return attention_mask - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - # Apply bidirectional mask on images if token type ids are provided - if token_type_ids is not None and sequence_length != 1: - token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) - token_type_mask[token_type_ids == 0] = False # if text token do not change anything - token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) - causal_mask = causal_mask.clone() - causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( - token_type_mask, 0.0 - ) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -__all__ = [ - "Gemma3Config", - "Gemma3TextConfig", - "Gemma3PreTrainedModel", # noqa: F822 - "Gemma3TextModel", - "Gemma3ForCausalLM", - "Gemma3ForConditionalGeneration", -] diff --git a/hf-gemma3/processing_gemma3.py b/hf-gemma3/processing_gemma3.py deleted file mode 100644 index e82b609bdb10..000000000000 --- a/hf-gemma3/processing_gemma3.py +++ /dev/null @@ -1,172 +0,0 @@ -# coding=utf-8 -# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import re -from typing import List, Optional, Union - -import numpy as np - -from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, make_nested_list_of_images -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack -from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import to_py_obj - - -class Gemma3ImagesKwargs(ImagesKwargs): - do_pan_and_scan: Optional[bool] - pan_and_scan_min_crop_size: Optional[int] - pan_and_scan_max_num_crops: Optional[int] - pan_and_scan_min_ratio_to_activate: Optional[float] - do_convert_rgb: Optional[bool] - - -class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): - images_kwargs: Gemma3ImagesKwargs - _defaults = { - "text_kwargs": { - "padding": False, - }, - "images_kwargs": { - "do_pan_and_scan": False, - "pan_and_scan_min_crop_size": 256, - "pan_and_scan_max_num_crops": 4, - "pan_and_scan_min_ratio_to_activate": 1.2, - }, - } - - -class Gemma3Processor(ProcessorMixin): - attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "image_seq_length"] - image_processor_class = "AutoImageProcessor" - tokenizer_class = "AutoTokenizer" - - def __init__( - self, - image_processor, - tokenizer, - chat_template=None, - image_seq_length: int = 256, - **kwargs, - ): - self.image_seq_length = image_seq_length - self.image_token_id = tokenizer.image_token_id - self.boi_token = tokenizer.boi_token - image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length) - self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" - - super().__init__( - image_processor=image_processor, - tokenizer=tokenizer, - chat_template=chat_template, - **kwargs, - ) - - def __call__( - self, - images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - videos=None, - audio=None, - **kwargs: Unpack[Gemma3ProcessorKwargs], - ) -> BatchFeature: - if text is None and images is None: - raise ValueError("Provide at least one of `text` or `images`.") - - output_kwargs = self._merge_kwargs( - Gemma3ProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) - - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, or a list of strings") - - image_inputs = {} - if images is not None: - batched_images = make_nested_list_of_images(images) - image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) - - # Create empty text to be replaced with placeholders - if not text: - text = [" ".join([self.boi_token] * len(images)) for images in batched_images] - - if len(batched_images) != len(text): - raise ValueError( - f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." - ) - - # Replace image tokens by the full expanded sequence - batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) - text_with_crops = text - for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)): - image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)] - - if len(images) != len(image_indexes): - raise ValueError( - f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images." - ) - - # Insert additional image tokens for Pan-and-Scan crops - for num, idx in reversed(list(zip(num_crops, image_indexes))): - if num: - formatted_image_text = ( - f"Here is the original image {self.boi_token} and here are some crops to help you see better " - + " ".join([self.boi_token] * num) - ) - prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :] - text_with_crops[batch_idx] = prompt - - # Expand placeholder image tokens to the full image token sequence - text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text] - - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") - - # Add token type ids manually, as tokenizer can't do arbitrary position token types - array_ids = np.array(text_inputs["input_ids"]) - mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) - mm_token_type_ids[array_ids == self.image_token_id] = 1 - text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs - text_inputs["token_type_ids"] = mm_token_type_ids.tolist() - return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) - - # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to - the docstring of this method for more information. - """ - return self.tokenizer.decode(*args, **kwargs) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] - image_processor_input_names = self.image_processor.model_input_names - return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) - - -__all__ = ["Gemma3Processor"] From f0f8e9d120ecbb77a6de2d5844ebe52d0a2d4167 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 16:28:54 -0700 Subject: [PATCH 19/52] revert Signed-off-by: Woosuk Kwon --- examples/offline_inference/vision_language.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index cbb686c32c93..1b57637c3bf8 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -700,7 +700,7 @@ def main(args): # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. - sampling_params = SamplingParams(temperature=0.0, + sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=stop_token_ids) From 28e757b328f34387d2d4afc10056dd9536c86b0e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 16:30:18 -0700 Subject: [PATCH 20/52] add placeholder str Signed-off-by: Woosuk Kwon --- vllm/entrypoints/chat_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index b05842dd27d3..c0c6a819091f 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -419,6 +419,8 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "aria": return "<|fim_prefix|><|img|><|fim_suffix|>" + if model_type == "gemma3": + return "" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": From 285ffc45d194a409cdcce546266391d92e2300ff Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 19:35:20 -0700 Subject: [PATCH 21/52] minor Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index b5cd54283533..012a6a3664ff 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -7,9 +7,7 @@ import torch from torch import nn -# FIXME -# from transformers import Gemma3Config -from transformers.models.gemma3.configuration_gemma3 import Gemma3Config +from transformers import Gemma3Config from transformers import BatchFeature, ProcessorMixin from vllm.config import VllmConfig From 713766b03653ddf73c95ae78073dd5e4d22e3121 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:05:37 -0700 Subject: [PATCH 22/52] Add comments Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 40 +++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 9399d28ea6cd..77829478f836 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -190,9 +190,38 @@ def forward( attn_output = self.attn(q, k, v) if not kwargs.get("has_images", False): + # Fast path for text-only inputs. The performance for the text-only + # inputs are not affected by the naive attention below. output, _ = self.o_proj(attn_output) return output + # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens + # that correspond to the same image while using causal attention + # otherwise. Current attention backends cannot handle this pattern, so + # we temporarily use a naive attention implementation with mask tensors. + + # We intentionally keep the attention backend as-is and only override + # `attn_output` with the naive implementation's output. This minimizes + # changes to existing model runners and attention backends. The call to + # `self.attn(q, k, v)` is only used to populate the KV cache - its + # output is discarded and overwritten below. While this duplicates + # computation, it maintains compatibility. + # TODO(woosuk): Optimize by implementing custom attention kernels. + attn_output = self.naive_attn_with_masks( + q, k, v, out=attn_output, **kwargs) + output, _ = self.o_proj(attn_output) + return output + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # NOTE(woosuk): As described in the comment above, this code is not + # meant to be performant. It is only meant to be correct. q = q.view(-1, self.num_heads, self.head_dim) # Expand the key and value to handle GQA. num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -219,16 +248,13 @@ def forward( key = key.transpose(1, 2) value = value.transpose(1, 2) - out = F.scaled_dot_product_attention( + output = F.scaled_dot_product_attention( query, key, value, attn_mask, self.scaling, ) - - out = out.transpose(1, 2).flatten(-2, -1) - attn_output[start_idx:end_idx] = out + output = output.transpose(1, 2).flatten(-2, -1) + out[start_idx:end_idx] = output start_idx = end_idx - - output, _ = self.o_proj(attn_output) - return output + return out class Gemma3DecoderLayer(nn.Module): From 67460869dde0a1e504c49a8c97715e1ca23aab41 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:06:08 -0700 Subject: [PATCH 23/52] Minor Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 012a6a3664ff..d9138adc071b 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import math -import re from collections.abc import Iterable, Mapping, Sequence from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -33,6 +31,7 @@ logger = init_logger(__name__) +# TODO(woosuk): Get these values from the model config. NUM_TOKENS_PER_IMAGE = 256 BOI_TOKEN = "" @@ -104,7 +103,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: - # FIXME(woosuk): Currently, PaS is not supported. + # TODO(woosuk): Support pan-and-scan. img_kwargs = mm_kwargs.get("images_kwargs", {}) if img_kwargs: img_kwargs["do_pan_and_scan"] = False @@ -252,7 +251,7 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) - assert image_embeds is None + assert image_embeds is None, "Gemma3 does not support image_embeds." if pixel_values is None: return None @@ -302,12 +301,7 @@ def get_input_embeddings( if multimodal_embeddings is None: inputs_embeds = self.language_model.get_input_embeddings(input_ids) else: - # NOTE(woosuk): Gemma3 uses vocab_size as the image token index. - # To avoid out-of-range error in the embedding layer, we replace the - # image token index with 0. - safe_input_ids = torch.where( - input_ids == self.config.image_token_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings(safe_input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, self.config.image_token_index) From 6fa03362550b3d660ee2e1e02da1f9b75c480b12 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:12:42 -0700 Subject: [PATCH 24/52] cleanup Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 96 +++++++++++++------------ 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index d9138adc071b..e231aa105056 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -209,6 +209,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config + self.sliding_window = config.text_config.interleaved_sliding_window self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, @@ -315,59 +316,18 @@ def forward(self, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: if intermediate_tensors is not None: inputs_embeds = None - kwargs.clear() # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - kwargs.clear() inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) if vision_embeddings is not None: - kwargs["has_images"] = True - start_idices = (positions == 0).cpu().nonzero() - num_seqs = len(start_idices) - seq_lens = [] - for i in range(num_seqs): - start_idx = start_idices[i].item() - if i < num_seqs - 1: - end_idx = start_idices[i + 1].item() - else: - end_idx = len(input_ids) - seq_lens.append(end_idx - start_idx) - kwargs["seq_lens"] = seq_lens - - global_attn_masks = [] - local_attn_masks = [] - for seq_len in seq_lens: - input_token_ids = input_ids[start_idx:end_idx] - global_attn_mask = torch.empty( - 1, 1, seq_len, seq_len, - dtype=vision_embeddings.dtype, - device=vision_embeddings.device, - ) - global_attn_mask.fill_(float("-inf")) - # Fill the lower triangle with 0. - global_attn_mask = global_attn_mask.triu(diagonal=1) - - img_mask = torch.zeros_like(global_attn_mask) - img_pos = (input_token_ids == self.config.image_token_index) - img_mask[:, :, :, img_pos] += 1 - img_mask[:, :, img_pos, :] += 1 - global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) - global_attn_masks.append(global_attn_mask) - - SLIDING_WINDOW_SIZE = 1024 - local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, diagonal=-SLIDING_WINDOW_SIZE) - local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) - local_attn_masks.append(local_attn_mask) - - kwargs["global_attn_masks"] = global_attn_masks - kwargs["local_attn_masks"] = local_attn_masks - + self.prepare_attn_masks(input_ids, positions, + mask_dtype=vision_embeddings.dtype, + **kwargs) input_ids = None hidden_states = self.language_model.model(input_ids, @@ -378,6 +338,54 @@ def forward(self, return hidden_states + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ) -> None: + kwargs["has_images"] = True + start_idices = (positions == 0).cpu().nonzero() + num_seqs = len(start_idices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_idices[i].item() + if i < num_seqs - 1: + end_idx = start_idices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + for seq_len in seq_lens: + input_token_ids = input_ids[start_idx:end_idx] + global_attn_mask = torch.empty( + 1, 1, seq_len, seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-self.sliding_window_size) + local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) + + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + def compute_logits( self, hidden_states: torch.Tensor, From 0384ceb7d6e90617b3c6d4002296b66df1a4e658 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:36:00 -0700 Subject: [PATCH 25/52] tmp Signed-off-by: Woosuk Kwon --- examples/offline_inference/basic/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index a6e96c0bb433..7162997056d5 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -13,7 +13,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="gg-hf-g/gemma-3-4b-it", tensor_parallel_size=2) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -21,4 +21,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 2beb199003ceaa874cdec7233a9ac7b53d5f216c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:44:31 -0700 Subject: [PATCH 26/52] minor Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3.py | 19 +++++--- vllm/model_executor/models/gemma3_mm.py | 65 +++++++++++++++---------- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 77829478f836..f1ecf7fa821d 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 # Copyright 2025 The vLLM team. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. # @@ -16,9 +17,9 @@ from typing import Iterable, Optional, Set, Tuple, Union import torch -from torch import nn import torch.nn.functional as F -from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig +from torch import nn +from transformers import Gemma3TextConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -138,7 +139,6 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) - # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. @@ -207,8 +207,11 @@ def forward( # output is discarded and overwritten below. While this duplicates # computation, it maintains compatibility. # TODO(woosuk): Optimize by implementing custom attention kernels. - attn_output = self.naive_attn_with_masks( - q, k, v, out=attn_output, **kwargs) + attn_output = self.naive_attn_with_masks(q, + k, + v, + out=attn_output, + **kwargs) output, _ = self.o_proj(attn_output) return output @@ -249,7 +252,11 @@ def naive_attn_with_masks( value = value.transpose(1, 2) output = F.scaled_dot_product_attention( - query, key, value, attn_mask, self.scaling, + query, + key, + value, + attn_mask, + self.scaling, ) output = output.transpose(1, 2).flatten(-2, -1) out[start_idx:end_idx] = output diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e231aa105056..0fa306e32b5a 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,33 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Iterable, Mapping, Sequence -from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set, + Tuple, TypedDict, Union) import torch from torch import nn -from transformers import Gemma3Config -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature, Gemma3Config, ProcessorMixin from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings, flatten_bn) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -149,29 +146,32 @@ def get_replacement_gemma3(item_idx: int): class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): super().__init__() self.mm_input_projection_weight = nn.Parameter( - torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) - ) + torch.zeros(config.vision_config.hidden_size, + config.text_config.hidden_size)) self.mm_soft_emb_norm = GemmaRMSNorm( - config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps - ) + config.vision_config.hidden_size, + eps=config.vision_config.layer_norm_eps) - self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.patches_per_image = int(config.vision_config.image_size // + config.vision_config.patch_size) self.tokens_per_side = int(config.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, + stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) reshaped_vision_outputs = reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, self.patches_per_image - ) + batch_size, seq_length, self.patches_per_image, + self.patches_per_image) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -180,7 +180,8 @@ def forward(self, vision_outputs: torch.Tensor): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) @@ -258,7 +259,7 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + f"Got type: {type(pixel_values)}") pixel_values = flatten_bn(pixel_values) return Gemma3ImagePixelInputs( @@ -325,7 +326,8 @@ def forward(self, inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) if vision_embeddings is not None: - self.prepare_attn_masks(input_ids, positions, + self.prepare_attn_masks(input_ids, + positions, mask_dtype=vision_embeddings.dtype, **kwargs) input_ids = None @@ -346,6 +348,8 @@ def prepare_attn_masks( **kwargs, ) -> None: kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. start_idices = (positions == 0).cpu().nonzero() num_seqs = len(start_idices) seq_lens = [] @@ -362,8 +366,12 @@ def prepare_attn_masks( local_attn_masks = [] for seq_len in seq_lens: input_token_ids = input_ids[start_idx:end_idx] + # Create a global causal mask. global_attn_mask = torch.empty( - 1, 1, seq_len, seq_len, + 1, + 1, + seq_len, + seq_len, dtype=mask_dtype, device=input_ids.device, ) @@ -371,6 +379,7 @@ def prepare_attn_masks( # Fill the lower triangle with 0. global_attn_mask = global_attn_mask.triu(diagonal=1) + # Consider the bidirectional attention between image tokens. img_mask = torch.zeros_like(global_attn_mask) img_pos = (input_token_ids == self.config.image_token_index) img_mask[:, :, :, img_pos] += 1 @@ -378,11 +387,13 @@ def prepare_attn_masks( global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) + # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, diagonal=-self.sliding_window_size) - local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window_size) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) local_attn_masks.append(local_attn_mask) - kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks From 64ef15f63da18dc6dbfa53a33df67116daa70ae0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:55:02 -0700 Subject: [PATCH 27/52] Docs Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index e46934b9caeb..70ec270423a4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -267,6 +267,11 @@ See [this page](#generative-models) for more information on how to use generativ * `google/gemma-2-9b`, `google/gemma-2-27b`, etc. * ✅︎ * ✅︎ +- * `Gemma3ForCausalLM` + * Gemma3 + * `google/gemma-3-1b-it`, etc. + * ✅︎ + * ✅︎ - * `GlmForCausalLM` * GLM-4 * `THUDM/glm-4-9b-chat-hf`, etc. @@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Gemma3ForConditionalGeneration` + * Gemma3 + * T + I + * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. + * ✅︎ + * ✅︎ + * ✅︎ - * `GLM4VForCausalLM`^ * GLM-4V * T + I From d92c7c1966f1335bfcf529060c8c5f159c47d93e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:57:34 -0700 Subject: [PATCH 28/52] ruff Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f8c1515e54eb..5dd3aa2973cd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -162,7 +162,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), - "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), From 82acdcdde9b71d40aeb52f5a1fd31741a3219453 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 20:58:45 -0700 Subject: [PATCH 29/52] Update transformers version Signed-off-by: Woosuk Kwon --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 13a06011e409..e4d508bc7e7d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -5,7 +5,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.48.2 # Required for Bamba model and Transformers backend. +transformers >= 4.50.0 # Required for Gemma3. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. From 29434439ea7d2a12660684649fa0991e921518c2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 21:12:02 -0700 Subject: [PATCH 30/52] minor Signed-off-by: Woosuk Kwon --- vllm/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index fced8958c94e..3d3252240fc4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -350,7 +350,12 @@ def __init__( if self.enforce_eager is None: self.enforce_eager = False - interleaved_attn_models = ["gemma2", "gemma3", "gemma3_text", "cohere2"] + interleaved_attn_models = [ + "gemma2", + "gemma3", # Gemma3 1B + "gemma3_text", # Gemma3 4/9/12/27B + "cohere2", + ] sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or From 635e1a9d264fe984090ea7c60e2624494dade198 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 21:54:27 -0700 Subject: [PATCH 31/52] Minor Signed-off-by: Woosuk Kwon --- vllm/config.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3d3252240fc4..77eac0619e49 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -350,12 +350,7 @@ def __init__( if self.enforce_eager is None: self.enforce_eager = False - interleaved_attn_models = [ - "gemma2", - "gemma3", # Gemma3 1B - "gemma3_text", # Gemma3 4/9/12/27B - "cohere2", - ] + interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or From 02bf606b016a43041eff849fc3cdacc455054df3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 11 Mar 2025 23:47:29 -0700 Subject: [PATCH 32/52] gg-hf-g -> google Signed-off-by: Woosuk Kwon --- examples/offline_inference/basic/basic.py | 4 ++-- examples/offline_inference/vision_language.py | 2 +- examples/offline_inference/vision_language_multi_image.py | 2 +- tests/models/multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 7162997056d5..a6e96c0bb433 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -13,7 +13,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="gg-hf-g/gemma-3-4b-it", tensor_parallel_size=2) +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -21,4 +21,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index ef1999c51c96..326ae31c81f2 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -121,7 +121,7 @@ def run_fuyu(questions: list[str], modality: str): def run_gemma3(question: str, modality: str): assert modality == "image" prompt = f" {question}" - model_name = "gg-hf-g/gemma-3-4b-it" + model_name = "google/gemma-3-4b-it" llm = LLM(model=model_name, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 9fc889a286a3..27c32e0987c4 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -81,7 +81,7 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: - model_name = "gg-hf-g/gemma-3-4b-it" + model_name = "google/gemma-3-4b-it" llm = LLM(model=model_name, max_model_len=8192, max_num_seqs=2, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 529ea5110407..467114eedb01 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -162,7 +162,7 @@ def _test_processing_correctness( "deepseek-ai/deepseek-vl2-tiny", "microsoft/Florence-2-base", "adept/fuyu-8b", - "gg-hf-g/gemma-3-4b-it", + "google/gemma-3-4b-it", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index f81013f99cd3..78d3b8b0baac 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -241,7 +241,7 @@ def check_available_online( "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "Gemma3ForConditionalGeneration": _HfExamplesInfo("gg-hf-g/gemma-3-4b-it"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 From 7097fa88092da3a366773f5ef4750915947a10a4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:14:17 -0700 Subject: [PATCH 33/52] Fix example Signed-off-by: Woosuk Kwon --- examples/offline_inference/vision_language.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 326ae31c81f2..429b30832a74 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -118,14 +118,15 @@ def run_fuyu(questions: list[str], modality: str): return llm, prompts, stop_token_ids -def run_gemma3(question: str, modality: str): +# Gemma 3 +def run_gemma3(questions: list[str], modality: str): assert modality == "image" - prompt = f" {question}" + prompts = [f" {question}" for question in questions] model_name = "google/gemma-3-4b-it" llm = LLM(model=model_name, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # GLM-4v From ac72f69dc50ad94705a5a4194d4aa3ac7bd29825 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:14:52 -0700 Subject: [PATCH 34/52] minor Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 0fa306e32b5a..9d9015a89aa1 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -102,10 +102,7 @@ def _call_hf_processor( ) -> BatchFeature: # TODO(woosuk): Support pan-and-scan. img_kwargs = mm_kwargs.get("images_kwargs", {}) - if img_kwargs: - img_kwargs["do_pan_and_scan"] = False - else: - img_kwargs = {"do_pan_and_scan": False} + img_kwargs["do_pan_and_scan"] = False mm_kwargs["images_kwargs"] = img_kwargs return super()._call_hf_processor( prompt=prompt, From 5ab60d014e9c4a6db9b74a2886898cb7ccf0156b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:23:18 -0700 Subject: [PATCH 35/52] minor Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 9d9015a89aa1..181136c4e5c4 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -4,7 +4,16 @@ import torch from torch import nn -from transformers import BatchFeature, Gemma3Config, ProcessorMixin +from transformers import BatchFeature, ProcessorMixin + +try: + from transformers import Gemma3Config +except ImportError as e: + raise ImportError( + "To use `Gemma3ForConditionalGeneration`, you have to install " + "Hugging Face Transformers library from source via " + "`pip install git+https://github.com/huggingface/transformers`." + ) from e from vllm.config import VllmConfig from vllm.logger import init_logger @@ -215,11 +224,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix, "vision_tower")) self.multi_modal_projector = Gemma3MultiModalProjector(config) - config.text_config.architectures = ["Gemma3ForCausalLM"] self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], ) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale From 3c42695fed502b2ac12754bd3f5e9056a5b5f5e8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:23:30 -0700 Subject: [PATCH 36/52] fix docs Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 70ec270423a4..81cc35a23c05 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -763,7 +763,7 @@ See [this page](#generative-models) for more information on how to use generativ * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. * ✅︎ * ✅︎ - * ✅︎ + * ✅︎\* - * `GLM4VForCausalLM`^ * GLM-4V * T + I @@ -949,6 +949,20 @@ For more details, please see: To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`. ::: +:::{note} +Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. +However, for text + image inputs, only V0 supports it correctly. +V1 does not strictly follow the original attention in Gemma 3. + +Specifically, the model uses bidirectional attention only for the image tokens. +Unfortunately, this attention pattern is not supported by any of the current attention backends. +Therefore, we temporarily use the naive PyTorch SDPA with masking tensors **in V0**. +This could lead to significant memory usage for long prompts (w/ images). + +In V1, we currently do not use the bidirectional attention for image tokens, and use the causal attention for all tokens. +The model still generates reasonable outputs, but this needs to be fixed to get the full accuracy when the input includes images. +::: + ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. From 883c656d8214e5ff1b84e8988977b31071270a99 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:25:07 -0700 Subject: [PATCH 37/52] minor docs Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 81cc35a23c05..d8b35823f048 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -954,7 +954,8 @@ Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. However, for text + image inputs, only V0 supports it correctly. V1 does not strictly follow the original attention in Gemma 3. -Specifically, the model uses bidirectional attention only for the image tokens. +Specifically, the model uses bidirectional attention between the image tokens while +using causal attention otherwise. Unfortunately, this attention pattern is not supported by any of the current attention backends. Therefore, we temporarily use the naive PyTorch SDPA with masking tensors **in V0**. This could lead to significant memory usage for long prompts (w/ images). From 1d2a06439e15f98cac7dc98e747931604eba0bb8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:28:21 -0700 Subject: [PATCH 38/52] polish Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index d8b35823f048..74f921c71f06 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -951,17 +951,20 @@ To use Qwen2.5-VL series models, you have to install Hugging Face Transformers l :::{note} Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. -However, for text + image inputs, only V0 supports it correctly. -V1 does not strictly follow the original attention in Gemma 3. +However, there are differences in how they handle text + image inputs: -Specifically, the model uses bidirectional attention between the image tokens while -using causal attention otherwise. -Unfortunately, this attention pattern is not supported by any of the current attention backends. -Therefore, we temporarily use the naive PyTorch SDPA with masking tensors **in V0**. -This could lead to significant memory usage for long prompts (w/ images). +V0 correctly implements the model's attention pattern: +- Uses bidirectional attention between the image tokens corresponding to the same image +- Uses causal attention for other tokens +- Implemented via (naive) PyTorch SDPA with masking tensors +- Note: May use significant memory for long prompts with image -In V1, we currently do not use the bidirectional attention for image tokens, and use the causal attention for all tokens. -The model still generates reasonable outputs, but this needs to be fixed to get the full accuracy when the input includes images. +V1 currently uses a simplified attention pattern: +- Uses causal attention for all tokens, including image tokens +- Generates reasonable outputs but does not match the original model's attention for text+image inputs +- Will be updated in the future to support the correct behavior + +This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. ::: ### Pooling Models From 0230237a2a5606d895d0ca2bae6451266d8c96d5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:28:59 -0700 Subject: [PATCH 39/52] polish Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 74f921c71f06..4631e1a95ea3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -961,7 +961,7 @@ V0 correctly implements the model's attention pattern: V1 currently uses a simplified attention pattern: - Uses causal attention for all tokens, including image tokens -- Generates reasonable outputs but does not match the original model's attention for text+image inputs +- Generates reasonable outputs but does not match the original model's attention for text + image inputs - Will be updated in the future to support the correct behavior This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. From 32ebaf1b42f89b6093e2113243f4f0ac059bb96a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:31:01 -0700 Subject: [PATCH 40/52] fix reqs Signed-off-by: Woosuk Kwon --- requirements/common.txt | 2 +- vllm/model_executor/models/gemma3_mm.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index e4d508bc7e7d..13a06011e409 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -5,7 +5,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.50.0 # Required for Gemma3. +transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 181136c4e5c4..92e6f17e2129 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -11,8 +11,8 @@ except ImportError as e: raise ImportError( "To use `Gemma3ForConditionalGeneration`, you have to install " - "Hugging Face Transformers library from source via " - "`pip install git+https://github.com/huggingface/transformers`." + "Hugging Face Transformers library from source via `pip install " + "git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3`." ) from e from vllm.config import VllmConfig From 4b65deb75de57d39844dcf0818168311d297cfbb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:44:00 -0700 Subject: [PATCH 41/52] Remove hardcoded values Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 27 +++++++++++++++---------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 92e6f17e2129..7172ad381821 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -37,10 +37,6 @@ logger = init_logger(__name__) -# TODO(woosuk): Get these values from the model config. -NUM_TOKENS_PER_IMAGE = 256 -BOI_TOKEN = "" - class Gemma3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -61,7 +57,8 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: - return {"image": NUM_TOKENS_PER_IMAGE} + hf_config = self.ctx.get_hf_config() + return {"image": hf_config.mm_tokens_per_image} def get_num_image_tokens( self, @@ -70,7 +67,8 @@ def get_num_image_tokens( image_height: int, processor: Optional[ProcessorMixin], ) -> int: - return NUM_TOKENS_PER_IMAGE + hf_config = self.ctx.get_hf_config() + return hf_config.mm_tokens_per_image def get_image_size_with_most_features(self) -> ImageSize: # Result in the max possible feature size (h:w = 16:1) @@ -84,8 +82,10 @@ def get_dummy_processor_inputs( seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) + tokenizer = self.info.get_tokenizer() + boi_token = tokenizer.boi_token + num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() @@ -96,7 +96,7 @@ def get_dummy_processor_inputs( num_images=num_images) } return ProcessorInputs( - prompt_text=" ".join([BOI_TOKEN] * num_images), + prompt_text=" ".join([boi_token] * num_images), mm_data=mm_data, ) @@ -132,9 +132,14 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_token = hf_processor.tokenizer.image_token - image_tokens_expanded = "".join([image_token] * NUM_TOKENS_PER_IMAGE) + hf_config = self.info.get_hf_config() + + boi_token = tokenizer.boi_token + image_token = tokenizer.image_token + mm_tokens_per_image = hf_config.mm_tokens_per_image + image_tokens_expanded = "".join([image_token] * mm_tokens_per_image) def get_replacement_gemma3(item_idx: int): return PromptUpdateDetails( @@ -145,7 +150,7 @@ def get_replacement_gemma3(item_idx: int): return [ PromptReplacement( modality="image", - target=BOI_TOKEN, + target=boi_token, replacement=get_replacement_gemma3, ) ] From 7ca73b1ccf6b09f4831e301ca02667164ef2ce95 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:48:39 -0700 Subject: [PATCH 42/52] Fix Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/gemma3_mm.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 7172ad381821..b7ffa57bbde4 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -337,10 +337,11 @@ def forward(self, inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) if vision_embeddings is not None: - self.prepare_attn_masks(input_ids, - positions, - mask_dtype=vision_embeddings.dtype, - **kwargs) + kwargs = self.prepare_attn_masks( + input_ids, + positions, + mask_dtype=vision_embeddings.dtype, + **kwargs) input_ids = None hidden_states = self.language_model.model(input_ids, @@ -357,7 +358,7 @@ def prepare_attn_masks( positions: torch.Tensor, mask_dtype: torch.dtype, **kwargs, - ) -> None: + ): kwargs["has_images"] = True # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. # This is a HACK. Fix this. @@ -401,12 +402,13 @@ def prepare_attn_masks( # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) local_attn_mask = torch.tril(local_attn_mask, - diagonal=-self.sliding_window_size) + diagonal=-self.sliding_window) local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks + return kwargs def compute_logits( self, From d5f2eef4442c07ef1cdde489e326180d6975e577 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 01:56:36 -0700 Subject: [PATCH 43/52] Add min_transformers_version Signed-off-by: Woosuk Kwon --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 78d3b8b0baac..205b233157fd 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -241,7 +241,8 @@ def check_available_online( "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", + min_transformers_version="4.50"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 From a4dbd56d1764d2ee6d197fdb4e6f3db7c21a2a7b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 02:00:04 -0700 Subject: [PATCH 44/52] fix Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 4 ++++ vllm/model_executor/models/gemma3_mm.py | 11 +---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 4631e1a95ea3..971ab3affc69 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -950,6 +950,10 @@ To use Qwen2.5-VL series models, you have to install Hugging Face Transformers l ::: :::{note} +To use Gemma3 series models, you have to install Hugging Face Transformers library from source via +`pip install git+https://github.com/huggingface/transformers`. +The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357). + Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. However, there are differences in how they handle text + image inputs: diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index b7ffa57bbde4..a21b526d7a65 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -4,16 +4,7 @@ import torch from torch import nn -from transformers import BatchFeature, ProcessorMixin - -try: - from transformers import Gemma3Config -except ImportError as e: - raise ImportError( - "To use `Gemma3ForConditionalGeneration`, you have to install " - "Hugging Face Transformers library from source via `pip install " - "git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3`." - ) from e +from transformers import BatchFeature, Gemma3Config, ProcessorMixin from vllm.config import VllmConfig from vllm.logger import init_logger From 2a8e2fa73093dff5cbea28c88fdfb438c0865252 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 02:05:17 -0700 Subject: [PATCH 45/52] comment about pas Signed-off-by: Woosuk Kwon --- docs/source/models/supported_models.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 971ab3affc69..28e789da4d66 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -969,6 +969,9 @@ V1 currently uses a simplified attention pattern: - Will be updated in the future to support the correct behavior This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. + +Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views. +Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions. ::: ### Pooling Models From 6ecd1ec3469fe10fe81e6c7e2c74447980cd4267 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 12 Mar 2025 03:10:01 -0700 Subject: [PATCH 46/52] fix input indices Signed-off-by: Roger Wang --- vllm/model_executor/models/gemma3_mm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index a21b526d7a65..c98412077206 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -367,8 +367,11 @@ def prepare_attn_masks( global_attn_masks = [] local_attn_masks = [] + start_idx = 0 for seq_len in seq_lens: + end_idx = start_idx + seq_len input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx # Create a global causal mask. global_attn_mask = torch.empty( 1, From e2e2a22a12247536e72bc1b2536b547fe56cb8dc Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 12 Mar 2025 04:59:04 -0700 Subject: [PATCH 47/52] fix batch with mixed numbers of images Signed-off-by: Roger Wang --- vllm/model_executor/models/gemma3_mm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index c98412077206..121aee51786b 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -259,11 +259,11 @@ def _parse_and_validate_image_input( if pixel_values is None: return None - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) return Gemma3ImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), From 49b29cae070d4cd873cda1bf62dbe8f0be2306df Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 12 Mar 2025 05:32:30 -0700 Subject: [PATCH 48/52] update examples with instruct format Signed-off-by: Roger Wang --- examples/offline_inference/vision_language.py | 4 +++- .../vision_language_multi_image.py | 19 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 429b30832a74..8a2257a30351 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -121,7 +121,9 @@ def run_fuyu(questions: list[str], modality: str): # Gemma 3 def run_gemma3(questions: list[str], modality: str): assert modality == "image" - prompts = [f" {question}" for question in questions] + prompts = [("user\n" + f"{question}\n" + "model\n") for question in questions] model_name = "google/gemma-3-4b-it" llm = LLM(model=model_name, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 27c32e0987c4..11a438ce54f7 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -86,7 +86,24 @@ def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: max_model_len=8192, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}) - prompt = " " * len(image_urls) + question + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) return ModelRequestData( llm=llm, prompt=prompt, From 41110046a4a626bc492c33d6603938bd72d15fe0 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 12 Mar 2025 05:38:44 -0700 Subject: [PATCH 49/52] update doc Signed-off-by: Roger Wang --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 28e789da4d66..5444f512e5de 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -759,7 +759,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Gemma3ForConditionalGeneration` * Gemma3 - * T + I + * T + I+ * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. * ✅︎ * ✅︎ From 154c24c5da96d0bec6b667bd6a67e11dcf0f3004 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 12 Mar 2025 14:38:32 +0000 Subject: [PATCH 50/52] Fix registry Signed-off-by: DarkLight1337 --- tests/models/registry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 205b233157fd..eadbd7e6f492 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -124,6 +124,8 @@ def check_available_online( "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), + "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", + min_transformers_version="4.50"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), From a2d2062c60a2a1b66b9017c4d01da2f2fb93d5d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 12 Mar 2025 14:40:25 +0000 Subject: [PATCH 51/52] Update doc Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 5444f512e5de..98e7572981de 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -263,12 +263,12 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ - * `Gemma2ForCausalLM` - * Gemma2 + * Gemma 2 * `google/gemma-2-9b`, `google/gemma-2-27b`, etc. * ✅︎ * ✅︎ - * `Gemma3ForCausalLM` - * Gemma3 + * Gemma 3 * `google/gemma-3-1b-it`, etc. * ✅︎ * ✅︎ @@ -509,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in * * - * `Gemma2Model` - * Gemma2-based + * Gemma 2-based * `BAAI/bge-multilingual-gemma2`, etc. * * ✅︎ @@ -758,7 +758,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ - * `Gemma3ForConditionalGeneration` - * Gemma3 + * Gemma 3 * T + I+ * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. * ✅︎ From d3286757f63d1baeccb34cb7dd272cfdc87e0952 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 12 Mar 2025 11:48:16 +0000 Subject: [PATCH 52/52] Clean up examples Signed-off-by: DarkLight1337 --- examples/offline_inference/vision_language.py | 12 ++++++++---- .../offline_inference/vision_language_multi_image.py | 3 +++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 8a2257a30351..39acab4765a3 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -121,12 +121,16 @@ def run_fuyu(questions: list[str], modality: str): # Gemma 3 def run_gemma3(questions: list[str], modality: str): assert modality == "image" - prompts = [("user\n" - f"{question}\n" - "model\n") for question in questions] model_name = "google/gemma-3-4b-it" + llm = LLM(model=model_name, + max_model_len=2048, + max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + + prompts = [("user\n" + f"{question}\n" + "model\n") for question in questions] stop_token_ids = None return llm, prompts, stop_token_ids @@ -418,7 +422,7 @@ def run_mllama(questions: list[str], modality: str): "type": "image" }, { "type": "text", - "text": f"{question}" + "text": question }] }] for question in questions] prompts = tokenizer.apply_chat_template(messages, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 11a438ce54f7..4963e6a8c4e7 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -82,10 +82,12 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: model_name = "google/gemma-3-4b-it" + llm = LLM(model=model_name, max_model_len=8192, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}) + placeholders = [{"type": "image", "image": url} for url in image_urls] messages = [{ "role": @@ -104,6 +106,7 @@ def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + return ModelRequestData( llm=llm, prompt=prompt,