Merged
3 changes: 3 additions & 0 deletions benchmarks/kernels/benchmark_moe.py
@@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        if not hasattr(config, "hidden_size"):
            # Support for llama4
            config = config.text_config
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
19 changes: 13 additions & 6 deletions docs/source/models/supported_models.md
@@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi

To check if the modeling backend is Transformers, you can simply do this:

```python
from vllm import LLM
llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
@@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM nor Transformers, you can sti
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!

```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
@@ -840,6 +840,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Llama4ForConditionalGeneration`
* Llama 4
* T + I<sup>+</sup>
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
*
*
* ✅︎
- * `LlavaForConditionalGeneration`
* LLaVA-1.5
* T + I<sup>E+</sup>
@@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
:::

<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;`--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'`
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
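
The override described in the <sup>^</sup> note can also be passed from the Python API. A minimal sketch, assuming `LLM` forwards `hf_overrides` to the engine arguments (the checkpoint name is only an example):

```python
from vllm import LLM

# Equivalent of: --hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'
llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",  # example checkpoint
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    # Depending on the checkpoint, trust_remote_code=True may also be needed.
)
```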

:::{important}
37 changes: 37 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)


def run_llama4(questions: list[str], modality: str):
    assert modality == "image"

    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=4,
        tensor_parallel_size=8,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        gpu_memory_utilization=0.4,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [[{
        "role": "user",
        "content": [{
            "type": "image"

Review comment: The image chunk appears to be missing its content; shouldn't it be
{"type": "image", "image": "https://path/to/your/image.jpg"}?

@ywang96 (Member, Apr 6, 2025): The way it works with our offline inference llm.generate interface is actually a bit different from the Hugging Face interface. In this case we're adding this chunk only so that the image placeholder token is inserted into the prompt when we apply the chat template from the tokenizer.

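            # Only the "type" key is needed here: it makes apply_chat_template
            # emit the <|image|> placeholder token in the prompt, while the
            # actual image is supplied separately as multimodal data.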
        }, {
            "type": "text",
            "text": f"{question}"
        }]
    }] for question in questions]
    prompts = tokenizer.apply_chat_template(messages,
                                            add_generation_prompt=True,
                                            tokenize=False)
    stop_token_ids = None
    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )


# Molmo
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -907,6 +943,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"minicpmv": run_minicpmv,
"mistral3": run_mistral3,
"mllama": run_mllama,
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"paligemma": run_paligemma,
38 changes: 38 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)


def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=4,
        tensor_parallel_size=8,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )


def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

@@ -567,6 +604,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
"llama4": load_llama4,
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
2 changes: 1 addition & 1 deletion requirements/common.txt
@@ -6,7 +6,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.50.3
transformers >= 4.51.0
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
2 changes: 1 addition & 1 deletion requirements/test.in
@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.50.3
transformers==4.51.0
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -645,7 +645,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.50.3
transformers==4.51.0
# via
# -r requirements/test.in
# genai-perf
16 changes: 16 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -536,6 +536,22 @@
            limit_mm_per_prompt={"image": 1},
        )],
    ),
    "llama4": VLMTestInfo(
        models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n",  # noqa: E501
        img_idx_to_prompt=lambda _: "<|image|>",
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        distributed_executor_backend="mp",
        image_size_factors=[(.25, 0.5, 1.0)],
        hf_model_kwargs={"device_map": "auto"},
        max_model_len=8192,
        max_num_seqs=4,
        dtype="bfloat16",
        auto_cls=AutoModelForImageTextToText,
        tensor_parallel_size=8,
        vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
        marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)],
    ),
}
# yapf: enable

1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
99 changes: 99 additions & 0 deletions tests/models/multimodal/processing/test_llama4.py
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Llama4's multimodal preprocessing kwargs."""

import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id",
                         ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict,
    num_imgs: int,
    disable_mm_preprocessor_cache: bool,
    tokenized_prompt: bool,
):
    """Ensure llama4 processor works properly."""
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": num_imgs},
        disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    config = processor.info.get_hf_config()
    tokenizer = processor.info.get_tokenizer()
    hf_processor = processor.info.get_hf_processor()
    vocab = tokenizer.get_vocab()

    prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
        + "<|image|>" * num_imgs \
        + "<|eot|><|header_start|>assistant<|header_end|>"
    mm_data = {
        "image": [
            image_assets[(i % len(image_assets))].pil_image
            for i in range(num_imgs)
        ]
    }
    if tokenized_prompt:
        prompt = encode_tokens(tokenizer, prompt)

    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
    mm_kwargs = processed_inputs["mm_kwargs"]

    # placeholder replacements
    prompt_token_ids = processed_inputs["prompt_token_ids"]
    assert prompt_token_ids.count(config.boi_token_index) == num_imgs
    assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
    assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
    aspect_ratios = mm_kwargs["aspect_ratios"]
    num_x_separators = num_y_separators = 0
    for tiles_y, tiles_x in aspect_ratios:
        if tiles_x * tiles_y > 1:
            num_x_separators += (tiles_x - 1) * tiles_y
            num_y_separators += tiles_y
    assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
        == num_x_separators
    assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
        == num_y_separators

    # image token offsets
    img_locs = processed_inputs["mm_placeholders"].get("image", [])
    assert len(img_locs) == num_imgs
    assert [img_loc["offset"] for img_loc in img_locs] == \
        [i for i, v in enumerate(prompt_token_ids)
         if v == config.boi_token_index]

    # patch sizes and masks
    assert prompt_token_ids.count(config.image_token_index) \
        == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
    patch_token_id = vocab[hf_processor.img_patch_token]
    num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
    mm_counts = {"image": num_imgs}
    assert num_patches / num_imgs <= \
        processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
    num_patches_per_chunk = processor.info.get_patch_per_chunk(
        config.vision_config)
    assert prompt_token_ids.count(config.image_token_index) \
        == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
    assert mm_kwargs["pixel_values"].shape[0] \
        == mm_kwargs["patches_per_image"].sum()

    for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
                                            mm_kwargs["aspect_ratios"]):
        assert embed_is_patch.shape[0] == \
            len(tokenizer.encode(
                hf_processor._prompt_split_image(
                    aspect_ratio, num_patches_per_chunk),
                add_special_tokens=False))
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -337,6 +337,7 @@ def check_available_online(
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}

12 changes: 10 additions & 2 deletions tests/models/test_registry.py
@@ -23,6 +23,11 @@

@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):

    # Llama4ForCausalLM does not have a standalone model
    if model_arch == "Llama4ForCausalLM":
        return

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_transformers_version(on_fail="skip")

@@ -91,8 +96,11 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):


def test_hf_registry_coverage():
    untested_archs = (ModelRegistry.get_supported_archs() -
                      HF_EXAMPLE_MODELS.get_supported_archs())
    untested_archs = set(ModelRegistry.get_supported_archs() -
                         HF_EXAMPLE_MODELS.get_supported_archs())

    # Llama4ForCausalLM does not have a standalone model
    untested_archs.discard("Llama4ForCausalLM")

    assert not untested_archs, (
        "Please add the following architectures to "
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -354,6 +354,8 @@ def __init__(
        self.hf_config = hf_config

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(self.hf_text_config,
                                            "attention_chunk_size", None)
        self.encoder_config = self._get_encoder_config()
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -500,7 +500,7 @@ def _placeholder_str(self, modality: ModalityStr,
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"