[Model] Support Llama4 in vLLM #16104
Changes from all commits: 58d9c2f, dcb2c77, 89083a6, 188bb52, 6ad393f, ee170a7, a19cf7b, bacd195, ec6cdaa, 62e9744, b4533e3, 0587bc7, c0ca739, 866b94a, 1b8b67a, 4e45bfc
The first hunk adds a `run_llama4` runner alongside the other vision-language example runners:

```diff
@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+def run_llama4(questions: list[str], modality: str):
+    assert modality == "image"
+
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=4,
+        tensor_parallel_size=8,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        gpu_memory_utilization=0.4,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [[{
+        "role":
+        "user",
+        "content": [{
+            "type": "image"
+        }, {
+            "type": "text",
+            "text": f"{question}"
+        }]
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            add_generation_prompt=True,
+                                            tokenize=False)
+    stop_token_ids = None
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+        stop_token_ids=stop_token_ids,
+    )
+
+
 # Molmo
 def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
```

Review thread on the `"type": "image"` entry:

> missing content of image, it should be […]

> The way it works with our offline inference […]
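As the review thread touches on, the `{"type": "image"}` entry above is only a chat-template placeholder; in vLLM offline inference the actual image is supplied separately as multimodal data alongside the templated prompt. The sketch below illustrates that pairing outside the example script's own plumbing; the question text, image path, and sampling settings are illustrative, not part of this diff.

```python
from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Build a templated prompt the same way run_llama4 does: the chat message
# only contains an image *placeholder* plus the question text.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
    "role": "user",
    "content": [{"type": "image"},
                {"type": "text", "text": "What is in this image?"}],
}]
prompt = tokenizer.apply_chat_template(messages,
                                       add_generation_prompt=True,
                                       tokenize=False)

llm = LLM(model=model_name,
          max_model_len=8192,
          max_num_seqs=4,
          tensor_parallel_size=8,
          gpu_memory_utilization=0.4)

# The image itself is passed here via multi_modal_data, not in the messages.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": Image.open("example.jpg")},  # hypothetical local file
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```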
A second hunk registers the new runner in the model example map:

```diff
@@ -907,6 +943,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "minicpmv": run_minicpmv,
     "mistral3": run_mistral3,
     "mllama": run_mllama,
+    "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
     "paligemma": run_paligemma,
```
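With this entry in place, the example script can dispatch to the Llama4 runner by model type like the other entries in the map. Note that `run_llama4` reads the script-level `args` namespace for `disable_mm_preprocessor_cache`, so it is intended to be invoked through the script's CLI rather than imported and called on its own.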
The second file is new (`@@ -0,0 +1,99 @@`): a multimodal processing test for Llama4, added in full.

```python
# SPDX-License-Identifier: Apache-2.0
"""Tests for Llama4's multimodal preprocessing kwargs."""

import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id",
                         ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict,
    num_imgs: int,
    disable_mm_preprocessor_cache: bool,
    tokenized_prompt: bool,
):
    """Ensure llama4 processor works properly."""
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": num_imgs},
        disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    config = processor.info.get_hf_config()
    tokenizer = processor.info.get_tokenizer()
    hf_processor = processor.info.get_hf_processor()
    vocab = tokenizer.get_vocab()

    prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
        + "<|image|>" * num_imgs \
        + "<|eot|><|header_start|>assistant<|header_end|>"
    mm_data = {
        "image": [
            image_assets[(i % len(image_assets))].pil_image
            for i in range(num_imgs)
        ]
    }
    if tokenized_prompt:
        prompt = encode_tokens(tokenizer, prompt)

    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
    mm_kwargs = processed_inputs["mm_kwargs"]

    # placeholder replacements
    prompt_token_ids = processed_inputs["prompt_token_ids"]
    assert prompt_token_ids.count(config.boi_token_index) == num_imgs
    assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
    assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
    aspect_ratios = mm_kwargs["aspect_ratios"]
    num_x_separators = num_y_separators = 0
    for tiles_y, tiles_x in aspect_ratios:
        if tiles_x * tiles_y > 1:
            num_x_separators += (tiles_x - 1) * tiles_y
            num_y_separators += tiles_y
    assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
        == num_x_separators
    assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
        == num_y_separators

    # image token offsets
    img_locs = processed_inputs["mm_placeholders"].get("image", [])
    assert len(img_locs) == num_imgs
    assert [img_loc["offset"] for img_loc in img_locs] == \
        [i for i, v in enumerate(prompt_token_ids)
         if v == config.boi_token_index]

    # patch sizes and masks
    assert prompt_token_ids.count(config.image_token_index) \
        == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
    patch_token_id = vocab[hf_processor.img_patch_token]
    num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
    mm_counts = {"image": num_imgs}
    assert num_patches / num_imgs <= \
        processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
    num_patches_per_chunk = processor.info.get_patch_per_chunk(
        config.vision_config)
    assert prompt_token_ids.count(config.image_token_index) \
        == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
    assert mm_kwargs["pixel_values"].shape[0] \
        == mm_kwargs["patches_per_image"].sum()

    for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
                                            mm_kwargs["aspect_ratios"]):
        assert embed_is_patch.shape[0] == \
            len(tokenizer.encode(
                hf_processor._prompt_split_image(
                    aspect_ratio, num_patches_per_chunk),
                add_special_tokens=False))
```
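To make the separator bookkeeping in the test concrete, here is a small standalone restatement of the same arithmetic with illustrative aspect-ratio values: an image tiled into tiles_y=2 rows by tiles_x=3 columns contributes (3-1)*2 = 4 x-separators and 2 y-separators, while a single-tile image contributes none.

```python
# Standalone restatement of the separator counting loop used in the test.
def expected_separator_counts(aspect_ratios):
    num_x_separators = num_y_separators = 0
    for tiles_y, tiles_x in aspect_ratios:
        if tiles_x * tiles_y > 1:
            # (tiles_x - 1) separators per tile row, across tiles_y rows.
            num_x_separators += (tiles_x - 1) * tiles_y
            # tiles_y y-separators whenever the image is split at all.
            num_y_separators += tiles_y
    return num_x_separators, num_y_separators


assert expected_separator_counts([(2, 3)]) == (4, 2)
assert expected_separator_counts([(1, 1)]) == (0, 0)  # single tile: no separators
assert expected_separator_counts([(2, 3), (1, 1)]) == (4, 2)
```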