
Commit 726fea2

WoosukKwon, ywang96, and DarkLight1337 authored and committed
[Model] Add support for Gemma 3 (vllm-project#14660)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 7276993 commit 726fea2

File tree

10 files changed: +1071 additions, -9 deletions


docs/source/models/supported_models.md

Lines changed: 39 additions & 2 deletions

@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
 - * `Gemma2ForCausalLM`
-  * Gemma2
+  * Gemma 2
   * `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
   * ✅︎
   * ✅︎
+- * `Gemma3ForCausalLM`
+  * Gemma 3
+  * `google/gemma-3-1b-it`, etc.
+  * ✅︎
+  * ✅︎
 - * `GlmForCausalLM`
   * GLM-4
   * `THUDM/glm-4-9b-chat-hf`, etc.
@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in
   *
   *
 - * `Gemma2Model`
-  * Gemma2-based
+  * Gemma 2-based
   * `BAAI/bge-multilingual-gemma2`, etc.
   *
   * ✅︎
@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `Gemma3ForConditionalGeneration`
+  * Gemma 3
+  * T + I<sup>+</sup>
+  * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
+  * ✅︎
+  * ✅︎
+  * ✅︎\*
 - * `GLM4VForCausalLM`<sup>^</sup>
   * GLM-4V
   * T + I
@@ -937,6 +949,31 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
 :::

+:::{note}
+To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers`.
+The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
+
+Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
+However, there are differences in how they handle text + image inputs:
+
+V0 correctly implements the model's attention pattern:
+- Uses bidirectional attention between the image tokens corresponding to the same image
+- Uses causal attention for other tokens
+- Implemented via (naive) PyTorch SDPA with masking tensors
+- Note: May use significant memory for long prompts with image
+
+V1 currently uses a simplified attention pattern:
+- Uses causal attention for all tokens, including image tokens
+- Generates reasonable outputs but does not match the original model's attention for text + image inputs
+- Will be updated in the future to support the correct behavior
+
+This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
+
+Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
+Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
+:::
+
 ### Pooling Models

 See [this page](pooling-models) for more information on how to use pooling models.
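
To make the V0 attention description in the note above concrete, here is a minimal, illustrative sketch of how such a mask can be built and passed to PyTorch SDPA. This is not vLLM's actual implementation; `image_ids` (the image index of each token, with -1 marking text tokens) and the toy sequence are assumptions made purely for illustration.

```python
# Illustrative sketch only -- not vLLM's implementation. Assumes `image_ids`
# marks each position with the index of the image it belongs to (-1 = text).
import torch
import torch.nn.functional as F


def build_gemma3_style_mask(image_ids: torch.Tensor) -> torch.Tensor:
    """Boolean mask where True means attention is allowed."""
    seq_len = image_ids.shape[0]
    # Start from a causal mask: each token attends to itself and the past.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Additionally allow bidirectional attention between tokens of the
    # same image (text tokens, marked -1, are excluded).
    is_image = image_ids >= 0
    same_image = ((image_ids[:, None] == image_ids[None, :])
                  & is_image[:, None] & is_image[None, :])
    return mask | same_image


# Toy example: 2 text tokens, a 3-token image, then 2 more text tokens.
image_ids = torch.tensor([-1, -1, 0, 0, 0, -1, -1])
mask = build_gemma3_style_mask(image_ids)

q = k = v = torch.randn(1, 1, 7, 8)  # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(mask.int())
```

Because the causal part always keeps the diagonal, no row of the mask is fully masked, so SDPA never produces NaN rows even for prompts with no image tokens.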

examples/offline_inference/vision_language.py

Lines changed: 19 additions & 1 deletion

@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
     return llm, prompts, stop_token_ids


+# Gemma 3
+def run_gemma3(questions: list[str], modality: str):
+    assert modality == "image"
+    model_name = "google/gemma-3-4b-it"
+
+    llm = LLM(model=model_name,
+              max_model_len=2048,
+              max_num_seqs=2,
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+
+    prompts = [("<bos><start_of_turn>user\n"
+                f"<start_of_image>{question}<end_of_turn>\n"
+                "<start_of_turn>model\n") for question in questions]
+    stop_token_ids = None
+    return llm, prompts, stop_token_ids
+
+
 # GLM-4v
 def run_glm4v(questions: list[str], modality: str):
     assert modality == "image"
@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
             "type": "image"
         }, {
             "type": "text",
-            "text": f"{question}"
+            "text": question
         }]
     }] for question in questions]
     prompts = tokenizer.apply_chat_template(messages,
@@ -664,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
     "deepseek_vl_v2": run_deepseek_vl2,
     "florence2": run_florence2,
     "fuyu": run_fuyu,
+    "gemma3": run_gemma3,
     "glm4v": run_glm4v,
     "h2ovl_chat": run_h2ovl,
     "idefics3": run_idefics3,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 37 additions & 0 deletions

@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
     )


+def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
+    model_name = "google/gemma-3-4b-it"
+
+    llm = LLM(model=model_name,
+              max_model_len=8192,
+              max_num_seqs=2,
+              limit_mm_per_prompt={"image": len(image_urls)})
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "h2oai/h2ovl-mississippi-800m"

@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
 model_example_map = {
     "aria": load_aria,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ def _test_processing_correctness(
     "deepseek-ai/deepseek-vl2-tiny",
     "microsoft/Florence-2-base",
     "adept/fuyu-8b",
+    "google/gemma-3-4b-it",
     "THUDM/glm-4v-9b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",

tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -124,6 +124,8 @@ def check_available_online(
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
+    "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
+                                         min_transformers_version="4.50"),
     "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
     "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
     "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
@@ -241,6 +243,8 @@ def check_available_online(
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
                                                hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
+    "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
+                                                      min_transformers_version="4.50"),
     "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
                                         trust_remote_code=True,
                                         hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501

vllm/config.py

Lines changed: 9 additions & 6 deletions

@@ -350,10 +350,11 @@ def __init__(
         if self.enforce_eager is None:
             self.enforce_eager = False

+        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
         has_interleaved_attention = (sliding_window is not None) and (
             isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+            (self.hf_text_config.model_type in interleaved_attn_models))

         if (not self.disable_sliding_window and has_interleaved_attention):
             if (backend :=
@@ -2503,11 +2504,11 @@ def _get_and_verify_dtype(
         dtype = dtype.lower()
         if dtype == "auto":
             if config_dtype == torch.float32:
-                if config.model_type == "gemma2":
+                if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
                     logger.info(
-                        "For Gemma 2, we downcast float32 to bfloat16 instead "
-                        "of float16 by default. Please specify `dtype` if you "
-                        "want to use float16.")
+                        "For Gemma 2 and 3, we downcast float32 to bfloat16 "
+                        "instead of float16 by default. Please specify `dtype` "
+                        "if you want to use float16.")
                     torch_dtype = torch.bfloat16
                 else:
                     # Following the common practice, we use float16 for float32
@@ -2639,7 +2640,9 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len

     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
+    # scaling, so we skip applying the scaling factor again.
+    if rope_scaling is not None and "gemma3" not in hf_config.model_type:
         # No need to consider "type" key because of patch_rope_scaling when
         # loading HF config
         rope_type = rope_scaling["rope_type"]
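
A practical consequence of the dtype change above: a float32 Gemma 2 or Gemma 3 checkpoint is loaded in bfloat16 unless `dtype` is passed explicitly. A minimal sketch, using a model name from this PR:

```python
from vllm import LLM

# With dtype="auto" (the default), a float32 Gemma 2/3 config is downcast to
# bfloat16, as the logger.info message above describes. Pass dtype explicitly
# to use float16 instead.
llm = LLM(model="google/gemma-3-1b-it", dtype="float16")
```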

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 0 deletions

@@ -433,6 +433,8 @@ def _placeholder_str(self, modality: ModalityStr,
                 return "<image>"
             if model_type == "aria":
                 return "<|fim_prefix|><|img|><|fim_suffix|>"
+            if model_type == "gemma3":
+                return "<start_of_image>"

             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
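
With the `<start_of_image>` placeholder registered above, vLLM's chat entrypoints can build Gemma 3 prompts from OpenAI-style messages without the caller writing the token by hand. A rough sketch using the offline `LLM.chat()` API; the image URL is a placeholder.

```python
from vllm import LLM

llm = LLM(model="google/gemma-3-4b-it", max_model_len=2048)

# The gemma3 branch added to _placeholder_str above supplies the
# "<start_of_image>" token for the image content part; the URL below is a
# placeholder.
messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}},
        {"type": "text", "text": "Describe this image."},
    ],
}]

outputs = llm.chat(messages)
print(outputs[0].outputs[0].text)
```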
