
Commit 50d3530

Authored by RyanMullins, xenova, pcuenca, ariG23498, and MayankChaturvedi
Gemma3 (#36658)
* Fix converter
* [Broken] Adds Gemma 3 to Hugging Face Transformers
* Consolidating Config and Processor params across impls
* Sorting out configuration parameters. Adds qk_norm before RoPE. Still not sure if RoPE is right.
* Additional plumbing for CausalLM and ConditionalGeneration variants
* incomplete draft of Orbax conversion script
* More complete checkpoint conversion
* Supporting Gemma 3 1B checkpoints
* Updating RoPE for multiple frequencies
* Adjustments to rotary embedder
* Proof of life for text-only operation
* Updating the conversion script to handle multimodal projection weights
* Fixing text-only conversions
* Cleaner conversion script with multimodal support and a simpler processor
* Additional refactors to the Gemma3Processor
* Simplified Processor to work over text representations
* Updated conversion script to join text and vision embeddings at conversion time
* Logging for debugging
* Update src/transformers/models/gemma2/modeling_gemma2.py
  Co-authored-by: Joshua Lochner <admin@xenova.com>
* Removed extraneous Config params
* Switching to fast tokenizer for checkpoint conversions
* isolating siglip for performance testing
* Minor changes for debugging tests against baselines
* Adding average pooling for soft tokens
* Updating processor code to enable simpler embedding interleaving for arbitrary number of images in prompts
* Updating conversion script for ShieldGemma 2 conversion compatibility
* Allow disable_compile to be provided as a kwarg
* Refresh from modular
* Updated conversion script and corrected sliding window
* Fix type mismatch in cache_position (#4)
* Fix dtype (#5)
  * Fix type mismatch in cache_position
  * Actually fix in the modular file
  Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
* fixes for embedding table overflow and missing image_soft_token_mask from Gemma3Processor
* Adding 2D pooling for image embeddings
* Revert "Adding 2D pooling for image embeddings"
  This reverts commit 65350cf.
* Gemma3 average pooling changed from 1D to 2D
* Major refactor to Gemma3MultimodalInputProjection
* Updating Gemma 3 Auto* registrations
* Add option to save Gemma 3 chat template with tokenizer during weights conversion
* Removing unused imports
* Moving out-of-vocab handling from Gemma3Processor to Gemma3ForConditionalGeneration
* Removing duplicate config property
* Removing final logit softcapping and 1-indexing of position ids
* Fixing image processor config and none --> None typo
* Fixing sliding window size for 1B
* Updating image_mean and image_std in Image Processor
* Attention masking changed to lower triangular
* Moving image special tokens to conversion script
* Mirror image processor defaults from conversion script into Gemma3ProcessorKwargs
* Remove special token variables from symbol space
* Moving image soft token mask computation from Gemma3Processor to Gemma3ForConditionalGeneration
* tie lm_head and embedding weights
  Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
* Correct tied weights in Gemma3CausalLM
* iterative bidirectional attention
* resolving merge conflicts
* Reverting to Gemma 2 HybridCache with sliding window support and a sliding_window_pattern of 6
* Correcting RoPE scaling
* clean up first pass, dummy model generation works
* final clean up before fixing tests
* causal lm test works, so fine
* Fix conversion
* Update src/transformers/models/gemma3/processing_gemma3.py
* model tests are happy
* processor tests are happy
* image processing tests added
* fixup
* Fix pre-processing in conversion
* Inputs merging
* Do not normalize vision embeddings
* Apply Ryan's (and team) changes to attention
* token type ids + mask
* template
* move embed scale, add rope scale, fix tests
* Add chat template to tokenizer
* Use prefix for causal model loading
* use existing code for sliding mask from gemma2
* self.embed_tokens already normalizes
* Correcting Gemma3TextConfig parameters in conversion script
* typo, modular overwrites my fixes
* enable device map for text model
* Conversion updates
* ultra nit: no einsums
* update image token
* copy deepcopy config + some docs
* add some tests, still WIP
* Refactoring --include_chat_template logic in converter
* Update src/transformers/models/gemma3/modular_gemma3.py
  Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
* Add eos tokens for instruct models
* dump so i can work on dgx
* Removing add_bos by default
* dump
* add fast im proc
* docs for PaS + fixup
* another fixup
* one more fixup
* fix tests
* Inverting prior BOS change
* ultra nit
* Reverting to Tokenizer saved with add_bos_token=True and chat template starting with BOS
* resize embeds, remove sqrt, add slow test outputs
* FA2 but quality is meh
* nit
* skip FA2, no idea what happened
* last bit for green CI
* please, green CI for docs
* T_T
* Fix for Gemma3 logits
* Support both options for system prompt
* Update src/transformers/models/gemma3/image_processing_gemma3_fast.py
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md (×5)
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Docs updates now that assets are live
* Style fixes

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Lysandre <hi@lysand.re>
1 parent 81aa9b2 commit 50d3530

33 files changed: +5469 −116 lines

docs/source/en/model_doc/gemma3.md

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Gemma3

## Overview

The Gemma 3 model was proposed in the [Gemma 3 Technical Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed of a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. The model cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed a certain aspect ratio. For images that exceed that aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where every sixth layer is a full causal attention layer.
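The hybrid attention layout can be made concrete with a small sketch (assuming the `Gemma3TextConfig` defaults from this release, including a `sliding_window_pattern` of 6):

```python
from transformers import Gemma3TextConfig

# Sketch only: every sixth layer uses full causal attention; the remaining
# layers use local sliding-window attention over `sliding_window` tokens.
config = Gemma3TextConfig()
pattern = getattr(config, "sliding_window_pattern", 6)
for i in range(config.num_hidden_layers):
    kind = "full causal" if (i + 1) % pattern == 0 else f"sliding window ({config.sliding_window})"
    print(f"layer {i:2d}: {kind}")
```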
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay), [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq).
## Usage tips

- For image+text and image-only inputs, use `Gemma3ForConditionalGeneration`.
- For text-only inputs, use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted (see the sketch after this list).
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
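As a minimal sketch of the last two tips (assuming the processor accepts raw text containing the `<start_of_image>` marker, as the tip describes), inputs can also be built without `apply_chat_template`:

```python
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

# The <start_of_image> token marks where the image's soft tokens are interleaved.
image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
prompt = "<start_of_image> What is shown in this image?"
inputs = processor(images=image, text=prompt, return_tensors="pt")
```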
### Image cropping for high-resolution images

The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image, improving quality on DocVQA or similar tasks that require higher-resolution images.

Pan and scan is an inference-time optimization for images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc.
```python
# Model and processor loaded up front so the snippet runs standalone.
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it", device_map="auto")
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")

url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    do_pan_and_scan=True,
).to(model.device)
```
## Usage Example

### Single-image Inference
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50)

# Slice off the prompt tokens before decoding so only the generation is printed
print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```
### Multi-image Inference
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": url_cow},
            {"type": "image", "url": url_stop},
            {"type": "text", "text": "Are these two images identical?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50)

# Slice off the prompt tokens before decoding so only the generation is printed
print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```
### Text-only inference

You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This skips loading the vision tower and saves resources when you only need the LLM capabilities.
```python
from transformers import AutoTokenizer, Gemma3ForCausalLM

model_id = "google/gemma-3-1b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")

input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)

outputs = model.generate(**input_ids, max_new_tokens=100)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(text)
```
## Gemma3ImageProcessor

[[autodoc]] Gemma3ImageProcessor

## Gemma3ImageProcessorFast

[[autodoc]] Gemma3ImageProcessorFast

## Gemma3Processor

[[autodoc]] Gemma3Processor

## Gemma3TextConfig

[[autodoc]] Gemma3TextConfig

## Gemma3Config

[[autodoc]] Gemma3Config

## Gemma3TextModel

[[autodoc]] Gemma3TextModel
    - forward

## Gemma3ForCausalLM

[[autodoc]] Gemma3ForCausalLM
    - forward

## Gemma3ForConditionalGeneration

[[autodoc]] Gemma3ForConditionalGeneration
    - forward

src/transformers/__init__.py

Lines changed: 21 additions & 1 deletion

```diff
@@ -474,6 +474,7 @@
     "models.fuyu": ["FuyuConfig"],
     "models.gemma": ["GemmaConfig"],
     "models.gemma2": ["Gemma2Config"],
+    "models.gemma3": ["Gemma3Config", "Gemma3Processor", "Gemma3TextConfig"],
     "models.git": [
         "GitConfig",
         "GitProcessor",
@@ -1259,6 +1260,7 @@
     _import_structure["models.emu3"].append("Emu3ImageProcessor")
     _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
     _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
+    _import_structure["models.gemma3"].append("Gemma3ImageProcessor")
     _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
     _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
     _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
@@ -1332,6 +1334,7 @@
     _import_structure["models.deit"].append("DeiTImageProcessorFast")
     _import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
     _import_structure["models.detr"].append("DetrImageProcessorFast")
+    _import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
     _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
     _import_structure["models.llava"].append("LlavaImageProcessorFast")
     _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
@@ -2452,6 +2455,14 @@
             "Gemma2PreTrainedModel",
         ]
     )
+    _import_structure["models.gemma3"].extend(
+        [
+            "Gemma3ForCausalLM",
+            "Gemma3ForConditionalGeneration",
+            "Gemma3PreTrainedModel",
+            "Gemma3TextModel",
+        ]
+    )
     _import_structure["models.git"].extend(
         [
             "GitForCausalLM",
@@ -2554,14 +2565,14 @@
             "GraniteMoePreTrainedModel",
         ]
     )
+
     _import_structure["models.granitemoeshared"].extend(
         [
             "GraniteMoeSharedForCausalLM",
             "GraniteMoeSharedModel",
             "GraniteMoeSharedPreTrainedModel",
         ]
     )
-
     _import_structure["models.grounding_dino"].extend(
         [
             "GroundingDinoForObjectDetection",
@@ -5629,6 +5640,7 @@
     from .models.fuyu import FuyuConfig
     from .models.gemma import GemmaConfig
     from .models.gemma2 import Gemma2Config
+    from .models.gemma3 import Gemma3Config, Gemma3Processor, Gemma3TextConfig
     from .models.git import (
         GitConfig,
         GitProcessor,
@@ -6450,6 +6462,7 @@
         FlavaProcessor,
     )
     from .models.fuyu import FuyuImageProcessor, FuyuProcessor
+    from .models.gemma3 import Gemma3ImageProcessor
     from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
     from .models.got_ocr2 import GotOcr2ImageProcessor
     from .models.grounding_dino import GroundingDinoImageProcessor
@@ -6535,6 +6548,7 @@
     from .models.deit import DeiTImageProcessorFast
     from .models.depth_pro import DepthProImageProcessorFast
     from .models.detr import DetrImageProcessorFast
+    from .models.gemma3 import Gemma3ImageProcessorFast
     from .models.got_ocr2 import GotOcr2ImageProcessorFast
     from .models.llava import LlavaImageProcessorFast
     from .models.llava_next import LlavaNextImageProcessorFast
@@ -7461,6 +7475,12 @@
         Gemma2Model,
         Gemma2PreTrainedModel,
     )
+    from .models.gemma3 import (
+        Gemma3ForCausalLM,
+        Gemma3ForConditionalGeneration,
+        Gemma3PreTrainedModel,
+        Gemma3TextModel,
+    )
     from .models.git import (
         GitForCausalLM,
         GitModel,
```
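With the lazy-import registrations above in place, the new public symbols should be importable from the top level. A quick smoke test (assuming a transformers build that includes this commit):

```python
# Each of these names is registered in both the lazy _import_structure and the
# TYPE_CHECKING branch above, so eager and lazy imports should agree.
from transformers import (
    Gemma3Config,
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
    Gemma3ImageProcessor,
    Gemma3ImageProcessorFast,
    Gemma3Processor,
    Gemma3TextConfig,
    Gemma3TextModel,
)

print(Gemma3Config.model_type)  # "gemma3"
```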

src/transformers/convert_slow_tokenizer.py

Lines changed: 11 additions & 9 deletions

```diff
@@ -113,10 +113,10 @@ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
         sp = self.sp
         vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

-        # there is a missing token in the vocab. We have to do this to support merges
+        # If "\t" is missing in the vocab, we have to do this to support merges
         # "<0x09>" is the bytefallback for `\t`
-        vocab["\t"] = vocab.get("<0x09>")
-
+        if "\t" not in vocab:
+            vocab["\t"] = vocab.get("<0x09>")
         merges = generate_merges(vocab, vocab_scores)
         return vocab, merges

@@ -1296,12 +1296,14 @@ def vocab(self, proto):
             (self.original_tokenizer.eos_token, 0.0),
             (self.original_tokenizer.bos_token, 0.0),
         ]
-        for piece in proto.pieces[3:]:
-            if piece.piece == "<0x09>":
-                vocab += [("\t", piece.score)]
-            else:
-                vocab += [(piece.piece, piece.score)]
-        # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+
+        # Older gemma tokenizers had a missing tab token, so we fix that here
+        if not any(x[0] == "\t" for x in vocab):
+            override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None)
+            if override_index is not None:
+                vocab[override_index] = ("\t", 0.0)
+
         return vocab

     def pre_tokenizer(self, replacement, add_prefix_space):
```
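The second hunk replaces the special-cased piece loop with a plain list comprehension plus a post-hoc fix. A standalone sketch of that tab-token override, with a toy vocab (illustrative values, not real tokenizer data):

```python
# Toy vocab: ("piece", score) pairs. Older Gemma tokenizers lack a literal "\t"
# entry and only carry its byte-fallback "<0x09>", so that entry is renamed in place.
vocab = [("<pad>", 0.0), ("<eos>", 0.0), ("<bos>", 0.0), ("<0x09>", -3.5), ("hello", -1.2)]

if not any(piece == "\t" for piece, _ in vocab):
    override_index = next((i for i, (piece, _) in enumerate(vocab) if piece == "<0x09>"), None)
    if override_index is not None:
        vocab[override_index] = ("\t", 0.0)

print(vocab[3])  # ('\t', 0.0)
```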

src/transformers/modeling_utils.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
     is_quantized = hf_quantizer is not None

     for serialized_param_name, empty_param in state_dict.items():
+        if serialized_param_name not in expected_keys:
+            continue
+
         # serialized_param_name is the raw, serialized name
         # fixed_param_name is the model's equivalent
         fixed_param_name, _ = model.rename_key(serialized_param_name)

-        if fixed_param_name not in expected_keys:
-            continue
-
         # we need to use serialized_param_name as file pointer is untouched
         if shard_file.endswith(".safetensors"):
             param = file_pointer.get_slice(serialized_param_name)
```
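This change moves the membership check ahead of the renaming step and performs it on the raw serialized name. A schematic sketch of the resulting control flow (toy helper names, not the actual transformers internals):

```python
# Schematic only: entries absent from expected_keys are skipped before any
# per-key work (renaming, reading a slice from the shard) is done.
def iter_expected_params(state_dict, expected_keys, rename_key):
    for serialized_name, empty_param in state_dict.items():
        if serialized_name not in expected_keys:
            continue  # early exit on the raw, serialized name
        fixed_name = rename_key(serialized_name)  # model's equivalent name
        yield serialized_name, fixed_name, empty_param
```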

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -106,6 +106,7 @@
     fuyu,
     gemma,
     gemma2,
+    gemma3,
     git,
     glm,
     glpn,
```

src/transformers/models/auto/configuration_auto.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -124,6 +124,8 @@
         ("fuyu", "FuyuConfig"),
         ("gemma", "GemmaConfig"),
         ("gemma2", "Gemma2Config"),
+        ("gemma3", "Gemma3Config"),
+        ("gemma3_text", "Gemma3TextConfig"),
         ("git", "GitConfig"),
         ("glm", "GlmConfig"),
         ("glpn", "GLPNConfig"),
@@ -459,6 +461,8 @@
         ("fuyu", "Fuyu"),
         ("gemma", "Gemma"),
         ("gemma2", "Gemma2"),
+        ("gemma3", "Gemma3ForConditionalGeneration"),
+        ("gemma3_text", "Gemma3ForCausalLM"),
         ("git", "GIT"),
         ("glm", "GLM"),
         ("glpn", "GLPN"),
@@ -748,6 +752,7 @@
         ("qwen2_audio_encoder", "qwen2_audio"),
         ("clip_text_model", "clip"),
         ("aria_text", "aria"),
+        ("gemma3_text", "gemma3"),
         ("idefics3_vision", "idefics3"),
         ("siglip_vision_model", "siglip"),
         ("smolvlm_vision", "smolvlm"),
```

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -86,6 +86,7 @@
         ("flava", ("FlavaImageProcessor",)),
         ("focalnet", ("BitImageProcessor",)),
         ("fuyu", ("FuyuImageProcessor",)),
+        ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("glpn", ("GLPNImageProcessor",)),
         ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
```
