Conversation

@houseroad
Collaborator

@houseroad houseroad commented Apr 5, 2025

Add support for Llama 4 Scout (17B x 16 experts) and Maverick (17B x 128 experts) in vLLM.

On 8xH100, vLLM can serve Scout with a 1M-token context and Maverick with about 430K.

vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --tensor-parallel-size 8 \
  --max-model-len 1280000

vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
  --tensor-parallel-size 8 \
  --max-model-len 430000

On 8xH200, vLLM can serve Scout with a 3.6M-token context and Maverick with the full 1M context.

vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --tensor-parallel-size 8 \
  --max-model-len 3600000

vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
  --tensor-parallel-size 8 

On MI300X, we can run with default settings.

VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_USE_MODELSCOPE=False \
SAFETENSORS_FAST_GPU=1 VLLM_USE_V1=1 vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct \
  --disable-log-requests -tp 8 \
  --max-num-seqs 64
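
Once one of the servers above is running, it exposes an OpenAI-compatible API. A minimal client sketch in Python (the localhost URL, port, prompt, and token budget are illustrative assumptions, not part of this PR):

from openai import OpenAI

# Point the client at the vLLM server started above (default host/port assumed).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "Summarize the Llama 4 release in two sentences."}],
    max_tokens=256,
)
print(response.choices[0].message.content)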

Check out the blog post [link coming soon] for performance enhancements and tips on leveraging the long context.

FIX #16106

@github-actions

github-actions bot commented Apr 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@houseroad houseroad marked this pull request as ready for review April 5, 2025 19:09
@mergify mergify bot added documentation Improvements or additions to documentation ci/build frontend multi-modality Related to multi-modality (#4194) v1 labels Apr 5, 2025
houseroad and others added 5 commits April 5, 2025 12:12
Co-authored-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Co-authored-by: Chris Thi <chris.c.thi@gmail.com>
Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Jon Swenson <jmswen@gmail.com>
Co-authored-by: Keyun Tong <tongkeyun@gmail.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
Co-authored-by: Lu Fang <lufang@meta.com>
Co-authored-by: Xiaodong Wang <xdwang@meta.com>
Co-authored-by: Yang Chen <yangche@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Lu Fang <lufang@fb.com>
This reverts commit 188bb52.

Signed-off-by: Lu Fang <lufang@fb.com>
@ywang96 ywang96 self-assigned this Apr 5, 2025
@robertgshaw2-redhat
Collaborator

🔥

Member

@ywang96 ywang96 left a comment


The multimodal part looks fine to me - I left some nits, but we can fix them later.

Comment on lines 100 to 101
assert topk == 1, \
"apply_router_weight_on_input is currently only implemented for topk=1"
Member


Should we move this assert inside the if apply_router_weight_on_input: conditional? Asserting without first checking whether apply_router_weight_on_input is true seems overly restrictive.
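
For reference, a sketch of the suggested change (identifiers taken from the diff above; this reflects the review suggestion, not necessarily the merged code):

if apply_router_weight_on_input:
    assert topk == 1, \
        "apply_router_weight_on_input is currently only implemented for topk=1"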

topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
Member


It looks like the attribute was forgotten here, unlike in the other method.

Collaborator


This is WIP by @luccafong

Collaborator


Are you referring to the pre-commit failure? Sorry, I think that was from my changes. @luccafong, I have a fix for this that I can push if you want; otherwise I can send you a patch (if you haven't already fixed it).

This reverts commit ee170a7.

Signed-off-by: Lu Fang <lufang@fb.com>
@dsingal0

dsingal0 commented Apr 5, 2025

Is it expected to get this error?

File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 451, in inspect_model_cls
    return self._raise_for_unsupported(architectures)
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 401, in _raise_for_unsupported
    raise ValueError(
ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

* fix lint

* remove unnecessary codes

* remove apply_router_weight_on_input from abstract class and remaining unrelated moe quantized methods
@ywang96
Member

ywang96 commented Apr 5, 2025

Is it expected to get this error: ... ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

@dsingal0 Which version of transformers are you on?

@dsingal0

dsingal0 commented Apr 5, 2025

Is it expected to get this error: ... ValueError: Model architectures ['Llama4ForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

@dsingal0 Which version of transformers are you on?

transformers-4.52.0.dev0

block_table: torch.tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
Member


Should it be named q_seqlens_np?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be - I just dropped the _np suffixes in this function since they are all numpy arrays, but we could add them back in a future PR.
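
For readers skimming the thread: the line in question converts cumulative query start offsets into per-request query lengths. A toy illustration (the offsets are made-up values):

import numpy as np

# hypothetical cumulative start offsets for three requests of lengths 3, 4, and 5
query_start_loc_np = np.array([0, 3, 7, 12])
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
print(q_seqlens)  # [3 4 5]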

Member

@tlrmchlsmth tlrmchlsmth Apr 5, 2025


LocalAttentionMetadata and make_local_attention_virtual_batches look good to me. BTW, has anybody profiled this? We should look at writing a "kernel" as a follow-up.

Collaborator


I don't believe so, at least I never did. As a first cut I think we could even just write a C++ op - this code is a lot easier to understand as a loop, and it would probably also be faster as a loop (assuming it's a C++ loop and not a Python loop) given how many numpy calls this version makes. I just wrote it this way assuming it would scale to larger batch sizes better than a Python loop.

@dsingal0

dsingal0 commented Apr 5, 2025

I think transformers.models.llama4.image_processing_llama4 needs to be changed to transformers.models.llama4.image_processing_llama4_fast

@ywang96
Member

ywang96 commented Apr 5, 2025

I think transformers.models.llama4.image_processing_llama4 needs to be changed to transformers.models.llama4.image_processing_llama4_fast

Yea it's been addressed in 62e9744 already

@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 5, 2025
ywang96 added 2 commits April 5, 2025 15:55
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
houseroad and others added 4 commits April 5, 2025 16:14
… apply (#4)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
* Add missing apply_router_weight_on_input arg to all FusedMoEMethodBase classes

* Make linter happy

* More lint fixes

* Revert "More lint fixes"

This reverts commit 675b3c1.
Signed-off-by: Roger Wang <ywang@roblox.com>
**kwargs)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    return {"image": 10}


Why limit it to only 10 images if the model should support far more, given its context length and the benchmark results published by Meta, which claim it can process up to 20 hours of video?

Member

@ywang96 ywang96 Apr 6, 2025


I don't think video inference is in scope for this release yet?

This PR doesn't support the video modality, so I guess it'll come in the next model update?

Collaborator

@yeqcharlotte yeqcharlotte Apr 6, 2025


@AlekseyKorshuk 8-10 images is the recommended multimodal limit for acceptable quality, although from the infra perspective it can handle more.

Llama 4's video tokenizer works slightly differently from the image one, and we'll update that once it's available.


That's a fair point, but it raises an error if the CLI argument sets the multimodal limit above 10. Shouldn't 10 be a default value rather than a hard limit that can't be overridden without changing the code?

Member


I think I'm okay with not capping it at 10, but a default value for this would be model-dependent, which vLLM doesn't support today (and it's tricky to do, since there's no standard for how many images a model can support), so we let users set it by passing limit-mm-per-prompt.


Sounds good - I just wanted to make sure this value is easy for users to change based on their needs. Thanks for the reply; I'll resolve the conversation.
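
For readers following this thread: the limit is meant to be raised per deployment via limit-mm-per-prompt rather than by editing the model code. A minimal sketch using the offline engine argument (the value 16 and the max_model_len are illustrative assumptions):

from vllm import LLM

# Raise the per-prompt image limit beyond the model's suggested default of 10.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=8,
    max_model_len=65536,
    limit_mm_per_prompt={"image": 16},
)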

Member

@ywang96 ywang96 left a comment


Given that the test failures are not particularly related to the changes in this PR and are non-blocking, I think this PR is good to go! Thanks to the Meta team for this amazing contribution to vLLM!

@simon-mo simon-mo merged commit c575232 into vllm-project:v0.8.3 Apr 6, 2025
60 of 67 checks passed
"role":
"user",
"content": [{
"type": "image"


The image content is missing; it should be:
{
    "type": "image",
    "image": "https://path/to/your/image.jpg"
}

Member

@ywang96 ywang96 Apr 6, 2025


The way it works with our offline-inference llm.generate interface is actually a bit different from the Hugging Face interface. In this case we add this chunk only so that the image placeholder token is inserted into the prompt when the chat template from the tokenizer is applied.
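
A minimal sketch of that offline flow (the model path, image file, and sampling settings are illustrative assumptions; the {"type": "image"} chunk carries no payload and only makes the chat template emit the image placeholder token, while the actual pixels go in via multi_modal_data):

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The image chunk has no content here; it only inserts the placeholder token
# when the chat template is applied.
messages = [{
    "role": "user",
    "content": [{"type": "image"},
                {"type": "text", "text": "Describe the image."}],
}]
prompt = tokenizer.apply_chat_template(messages,
                                       add_generation_prompt=True,
                                       tokenize=False)

llm = LLM(model=model_id, tensor_parallel_size=8, max_model_len=8192)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("example.jpg")}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)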

houseroad added a commit to houseroad/vllm that referenced this pull request Apr 6, 2025
Co-authored-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Co-authored-by: Chris Thi <chris.c.thi@gmail.com>
Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Jon Swenson <jmswen@gmail.com>
Co-authored-by: Keyun Tong <tongkeyun@gmail.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
Co-authored-by: Lu Fang <lufang@meta.com>
Co-authored-by: Xiaodong Wang <xdwang@meta.com>
Co-authored-by: Yang Chen <yangche@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
@fsaudm

fsaudm commented Apr 6, 2025

Quantization support?
