
Conversation

@Lucaskabela (Contributor) commented Aug 19, 2025

We enable @supports_torch_compile on generic nn.Modules (as opposed to only top-level model architectures) and demonstrate the application in qwen_2_5_vl.

Purpose

This PR is a first step towards supporting torch.compile for multimodal encoders (such as vision or audio). Since these modality-specific components do not have a vLLM config, we begin by allowing these modules to be compiled on their own.
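
To make the idea easier to picture, here is a minimal, self-contained sketch of compiling a generic sub-module on its own (rather than only the top-level architecture). Plain torch.compile stands in for vLLM's decorator machinery, and ToyVisionBlock is purely illustrative, not code from this PR:

import torch
import torch.nn as nn


class ToyVisionBlock(nn.Module):
    """Stand-in for a modality-specific encoder block that has no vLLM config."""

    def __init__(self, dim: int = 64) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(self.norm(x))


block = ToyVisionBlock()
# Compile just this sub-module; the surrounding model is left untouched.
compiled_block = torch.compile(block, dynamic=True)
print(compiled_block(torch.randn(2, 16, 64)).shape)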

Test Plan

Unit Test

with-proxy pytest tests/compile/test_multimodal_compile.py
1 passed, 9 warnings in 45.26s

E2E Tests

VLLM_USE_V1=1 python examples/offline_inference/vision_language.py -m qwen2_5_vl

Test Result

BEFORE CHANGES:

(EngineCore_0 pid=2638689) INFO 08-19 14:14:20 [backends.py:559] Dynamo bytecode transform time: 3.26 s
(EngineCore_0 pid=2638689) INFO 08-19 14:14:22 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 1.965 s
(EngineCore_0 pid=2638689) INFO 08-19 14:14:23 [monitor.py:34] torch.compile takes 3.26 s in total
...

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:  25%|██▌       | 1/4 [00:00<00:02,  1.41it/s, est. speed input: 1800.03 toks/s, output: 90.28 toks/s]
Processed prompts: 100%|██████████| 4/4 [00:00<00:00,  1.41it/s, est. speed input: 7137.36 toks/s, output: 357.98 toks/s]
Processed prompts: 100%|██████████| 4/4 [00:00<00:00,  5.59it/s, est. speed input: 7137.36 toks/s, output: 357.98 toks/s]

AFTER CHANGES:

(EngineCore_0 pid=2569806) INFO 08-19 14:12:31 [backends.py:559] Dynamo bytecode transform time: 0.37 s
(EngineCore_0 pid=2569806) INFO 08-19 14:12:31 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 0.020 s
(EngineCore_0 pid=2569806) INFO 08-19 14:12:31 [monitor.py:34] torch.compile takes 0.37 s in total
(EngineCore_0 pid=2569806) INFO 08-19 14:12:34 [backends.py:548] Using cache directory: /home/lucaskabela/.cache/vllm/torch_compile_cache/1de7867ecf/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_0 pid=2569806) INFO 08-19 14:12:34 [backends.py:559] Dynamo bytecode transform time: 3.16 s
(EngineCore_0 pid=2569806) INFO 08-19 14:12:37 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 2.109 s
(EngineCore_0 pid=2569806) INFO 08-19 14:12:38 [monitor.py:34] torch.compile takes 3.53 s in total
...
Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:  25%|██▌       | 1/4 [00:00<00:02,  1.43it/s, est. speed input: 1822.83 toks/s, output: 91.42 toks/s]
Processed prompts: 100%|██████████| 4/4 [00:00<00:00,  1.43it/s, est. speed input: 7228.82 toks/s, output: 362.57 toks/s]
Processed prompts: 100%|██████████| 4/4 [00:00<00:00,  5.66it/s, est. speed input: 7228.82 toks/s, output: 362.57 toks/s]

Benchmarks

vllm serve Qwen/Qwen2.5-VL-3B-Instruct
vllm bench serve \
  --backend openai-chat \
  --endpoint-type openai-chat \
  --model Qwen/Qwen2.5-VL-3B-Instruct \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
  --dataset-path lmarena-ai/VisionArena-Chat \
  --hf-split train \
  --num-prompts 1000

Main + Torch (nightly)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  135.59    
Total input tokens:                      94327     
Total generated tokens:                  106141    
Request throughput (req/s):              7.37      
Output token throughput (tok/s):         782.79    
Total Token throughput (tok/s):          1478.44   
---------------Time to First Token----------------
Mean TTFT (ms):                          63924.39  
Median TTFT (ms):                        63589.24  
P99 TTFT (ms):                           127523.65 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          681.74    
Median TPOT (ms):                        661.11    
P99 TPOT (ms):                           1375.61   
---------------Inter-token Latency----------------
Mean ITL (ms):                           596.16    
Median ITL (ms):                         78.84     
P99 ITL (ms):                            1564.77   
==================================================

This PR + Torch (nightly)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  129.73    
Total input tokens:                      94327     
Total generated tokens:                  106596    
Request throughput (req/s):              7.71      
Output token throughput (tok/s):         821.69    
Total Token throughput (tok/s):          1548.80   
---------------Time to First Token----------------
Mean TTFT (ms):                          61043.37  
Median TTFT (ms):                        60044.40  
P99 TTFT (ms):                           123046.15 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          648.66    
Median TPOT (ms):                        625.63    
P99 TPOT (ms):                           1308.29   
---------------Inter-token Latency----------------
Mean ITL (ms):                           569.48    
Median ITL (ms):                         76.11     
P99 ITL (ms):                            1492.54   
==================================================

So with this PR, we observe moderate improvements to several key metrics: ~5% relative improvement in throughput, ~2% in TTFT, ~7% in TPOT, and ~7% in mean ITL (median ITL shows only a minor improvement).

This requires only a minor rewrite (moving attention into a custom operator) in order to torch.compile the main vision backbone (the VisionBlock).

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@github-actions (bot)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the qwen Related to Qwen models label Aug 19, 2025
@miladm miladm self-requested a review August 19, 2025 22:42
@Lucaskabela Lucaskabela marked this pull request as ready for review August 19, 2025 22:51
@ProExpertProg (Collaborator)

Sorry I'm not familiar, why do mm components not have vllm*_config set?

@tanruixiang (Contributor)

This modification may cause recompilation during runtime. It appears that this may cause timeouts in some scenarios.🤔

@Lucaskabela (Contributor, Author)

Sorry I'm not familiar, why do mm components not have vllm*_config set?

I am honestly not too sure; an alternative we could do here is to forward the vllm_config to these subclasses. This would be a less flexible solution (as it requires any nn.Module we want to compile to accept vllm_config in its init), but it is doable.
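
For illustration, a rough sketch of that alternative - threading the config down so a compile decorator could read it from the sub-module. The constructor signature here is hypothetical, not code from this PR:

import torch
import torch.nn as nn


class VisionEncoder(nn.Module):
    # Hypothetical plumbing: the sub-module stores the engine-level config so
    # decorator-based compilation could look up settings on self.vllm_config.
    def __init__(self, vllm_config, dim: int = 64) -> None:
        super().__init__()
        self.vllm_config = vllm_config
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)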

@Lucaskabela (Contributor, Author)

This modification may cause recompilation during runtime. It appears that this may cause timeouts in some scenarios.🤔

This is a valid concern, and I will stress test this to check - if there is some easy way to trigger recompiles, can you share examples or a script that would do so? In the meanwhile, I will continue to test and try to trigger recompiles on my own. Thanks!

@zou3519 (Collaborator) commented Aug 25, 2025

This modification may cause recompilation during runtime. It appears that this may cause timeouts in some scenarios.🤔

@tanruixiang what do you mean by this? The way the vLLM-compile integration works, vLLM deletes the guards that TorchDynamo produces, so there will never be recompilations during inference serving time.

@tanruixiang (Contributor)

@tanruixiang what do you mean by this? The way the vLLM-compile integration works, vLLM deletes the guards that TorchDynamo produces, so there will never be recompilations during inference serving time.

@zou3519 Thanks for pointing this out. I also saw the Slack discussion. Let's wait for the benchmark first.

@mergify mergify bot added documentation Improvements or additions to documentation llama Related to Llama models labels Aug 26, 2025
@Lucaskabela (Contributor, Author)

@tanruixiang @zou3519 @ProExpertProg @ywang96 updated with some benchmark results for more context on our discussions

@miladm (Collaborator) commented Aug 27, 2025

@Lucaskabela do we have a bug for future steps of this effort?

@Lucaskabela (Contributor, Author)

@miladm currently just waiting on feedback here

@Lucaskabela Lucaskabela force-pushed the lucaskabela/compile_nn_module branch from 37ba5ac to d7fd15b Compare September 8, 2025 17:31
@ywang96 (Member) left a comment

Hey @Lucaskabela! I actually had a chat with @youkaichao and the conclusion is that this current state is probably not good enough for us to merge this PR.

Generally speaking, while it's feasible for us to add torch.compile support for encoders on a model-by-model basis (we're already doing model-by-model encoder DP support), there should be clear evidence that this brings a performance benefit, since it adds additional startup time. WDYT?

@Lucaskabela (Contributor, Author)

Hi @ywang96 I think that is a very valid point - I will work on applying this to some places in order to show strong benefits

@Lucaskabela Lucaskabela changed the title [Misc][qwen2_5_vl] Enable supports_torch_compile on generic nn.Module [DRAFT][Misc][qwen2_5_vl] Enable supports_torch_compile on generic nn.Module Sep 10, 2025
@mergify (bot) commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Lucaskabela.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025

output, _ = self.proj(context_layer)
return output


@torch.library.custom_op("mylib::attn_executor", mutates_args=())
@Lucaskabela (Contributor, Author) commented on the hunk above:

NOTE: We push all the compile-breaking things here (such as .item/.tolist). This way, we can specify this as a custom op and compile around it
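
As a rough illustration of this pattern (a minimal sketch only - the op name, shapes, and body below are made up and are not the PR's actual attention executor):

import torch


@torch.library.custom_op("mylib::masked_sum", mutates_args=())
def masked_sum(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # .tolist() would normally force a graph break under torch.compile, but
    # inside a custom op the compiler treats the whole call as opaque.
    return torch.stack([x[i, :n].sum() for i, n in enumerate(lengths.tolist())])


@masked_sum.register_fake
def _(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation so the compiler can infer the output shape.
    return x.new_empty(x.shape[0])


@torch.compile(fullgraph=True)
def model_step(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    return masked_sum(x, lengths) * 2.0


print(model_step(torch.randn(3, 5), torch.tensor([2, 5, 1])))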

@laithsakka (Contributor) commented Oct 28, 2025

Hmm, why are we not just allowing unbacked ints in the graph instead? Have we tried?
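
For reference, a minimal sketch of the unbacked-ints route being asked about, relying on torch._dynamo.config.capture_scalar_outputs and torch._check_is_size as in recent PyTorch releases; treat it as an assumption-laden illustration rather than something validated against this model:

import torch

# Let .item() produce an unbacked SymInt instead of forcing a graph break.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def make_buffer(length_tensor: torch.Tensor) -> torch.Tensor:
    n = length_tensor.item()   # unbacked SymInt flows through the graph
    torch._check_is_size(n)    # tell the compiler n is a valid (>= 0) size
    return torch.zeros(n)


print(make_buffer(torch.tensor(5)).shape)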

@mergify mergify bot removed the needs-rebase label Sep 11, 2025
@DarkLight1337 (Member)

Does #27705 fix the issue?

@ZJY0516 (Contributor) commented Oct 29, 2025

Hi! @Lucaskabela @ywang96 This PR completely breaks the inference when I send an image to the model, for example I send it an invoice, and it tells me that it cannot see absolutely anything.

Could you provide a repro script?

@JartX (Contributor) commented Oct 29, 2025

@DarkLight1337 no, I have only reverted #27705 to avoid problems with this one: #23207

@ZJY0516

VLLM_ATTENTION_BACKEND=TORCH_SDPA lm_eval \
  --model local-chat-completions \
  --model_args model=Qwen/Qwen3-VL-30B-A3B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=10,max_retries=3,tokenized_requests=False \
  --tasks chartqa \
  --batch_size auto \
  --apply_chat_template

(run after starting the vLLM server)

@JartX (Contributor) commented Oct 29, 2025

@ZJY0516 You can also try sending an image. If I send it the following: https://ibb.co/XZVsT7dB
it says it sees absolutely nothing - just a black background with wavy lines.

It's this PR, #23207 (the one we are on), that has broken my inference, not @lgeiger's.

@lgeiger (Contributor) commented Oct 29, 2025

@JartX what hardware are you testing this on?

@JartX (Contributor) commented Oct 29, 2025

@lgeiger
AMD ROCm, RDNA3; ViT attention backend: TORCH_SDPA

My docker run command:

docker run -d --name vllm1 --tty --restart unless-stopped --shm-size 48gb -p 80:8000 \
  --device /dev/kfd --device /dev/dri --group-add video --ipc host --network host \
  --cap-add SYS_PTRACE --security-opt seccomp=unconfined --privileged \
  -v ./chat_vl.jinja:/chat-template-tools.jinja \
  -e HSA_OVERRIDE_GFX_VERSION=11.0.0 -e VLLM_USE_V1=1 -e VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 \
  -e MIOPEN_USER_DB_PATH=/apps/miopen-cache -e MIOPEN_CUSTOM_CACHE_DIR=/apps/miopen-cache \
  -e OMP_NUM_THREADS=1 -e VLLM_ENABLE_V1_MULTIPROCESSING=1 \
  vllm-rocm-251029_broken_test_1 \
  vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct --gpu-memory-utilization 0.98 --max_model_len 65536 -tp 4 \
  --served-model-name QWEN3 --port 80 --limit-mm-per-prompt '{"image":6,"video":0}' \
  --dtype float16 --enable-log-requests --chat-template /chat-template-tools.jinja

@DarkLight1337 (Member)

cc @tjtanaa for AMD stuff

@lgeiger (Contributor) commented Oct 29, 2025

@JartX Maybe we need to add back the ROCm-only contiguous()? https://github.com/vllm-project/vllm/pull/23207/files#diff-19c9c9ac2ca98754bf900c292a95b9930f352f63f8619e3e0026a8e31f3e4396L432-L437

@JartX (Contributor) commented Oct 29, 2025

Hi @lgeiger, could you add the contiguous() in the torch SDPA path in your PR, please?

https://github.com/vllm-project/vllm/pull/27741/files

@lgeiger (Contributor) commented Oct 29, 2025

Hi @lgeiger, could you add the contiguous() in the torch SDPA path in your PR, please?

@JartX I don't have access to AMD hardware to test, so I would appreciate it if you or somebody else with the ability to test could make a PR with these changes.

@tjtanaa (Collaborator) commented Oct 29, 2025

I will be debugging on Mi300X.

@JartX (Contributor) commented Oct 29, 2025

adding

if current_platform.is_rocm():
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

works again

@tjtanaa @DarkLight1337 @lgeiger I'm going to put up a quick PR - thanks @lgeiger
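
For context, a sketch of where that guard sits relative to the SDPA call; the surrounding function is illustrative, only the is_rocm()/contiguous() guard mirrors the snippet above, and the vllm.platforms import path is an assumption:

import torch
import torch.nn.functional as F

from vllm.platforms import current_platform  # assumed import path


def sdpa_vision_attention(q: torch.Tensor, k: torch.Tensor,
                          v: torch.Tensor) -> torch.Tensor:
    if current_platform.is_rocm():
        # Some ROCm SDPA paths mishandle non-contiguous inputs, so force a
        # contiguous layout before dispatching.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return F.scaled_dot_product_attention(q, k, v)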

@tjtanaa (Collaborator) commented Oct 29, 2025

@JartX go ahead and open that bugfix PR. For other paths on ROCm, we still need to do some fixing.

ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
…generic nn.Module and demonstrate speedup on Qwen Vision model (vllm-project#23207)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
…generic nn.Module and demonstrate speedup on Qwen Vision model (vllm-project#23207)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
…generic nn.Module and demonstrate speedup on Qwen Vision model (vllm-project#23207)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Nov 12, 2025
### What this PR does / why we need it?
adapt vllm-ascend main branch with vllm releases/v0.11.1

fix `forward context not set` in test_vlm.py caused by:
vllm-project/vllm#23207

fix import `cdiv round` failed caused by:
vllm-project/vllm#27188

fix import `init_cached_hf_modules` failed caused by:
vllm-project/vllm#27567

adapt triton kernel `fused_recurrent_gated_delta_rule_fwd_kernel` caused
by: vllm-project/vllm#27654
- remove unused code in sigmoid_gating.py
- `class FusedRecurrentFunction` , `fused_recurrent_gated_delta_rule`,
`fused_recurrent_gated_delta_rule_fwd`

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI 


- vLLM version: v0.11.0
- vLLM main:
vllm-project/vllm@83f478b

Signed-off-by: 22dimensions <waitingwind@foxmail.com>

Labels

documentation (Improvements or additions to documentation), llama (Related to Llama models), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed)
