
Conversation


@zou3519 zou3519 commented Apr 18, 2025

When people say they don't like the "torch.compile startup time", they mean two things:

  1. the cold start time
  2. the warm start time (when the vLLM disk cache has already been populated).

We had logging for (1) but not for (2). This PR adds logging for (2).
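
For context, the new log is conceptually just a timer around the warm-start path. A minimal sketch of the pattern, where `load_compiled_graph_with_timing` and `cache.load` are hypothetical stand-ins rather than the actual vLLM code:

```python
import logging
import time

logger = logging.getLogger(__name__)

def load_compiled_graph_with_timing(cache, shape):
    # Hypothetical stand-in for the warm-start path in backends.py:
    # read a previously compiled graph from the vLLM disk cache and
    # log how long the load took.
    start = time.perf_counter()
    graph = cache.load(shape)  # stand-in for the real cache lookup
    elapsed = time.perf_counter() - start
    logger.info(
        "Directly load the compiled graph(s) for shape %s "
        "from the cache, took %.3f s", shape, elapsed)
    return graph
```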

Test Plan:

I ran `VLLM_USE_V1=1 python benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 -O '{"level": 3, "compile_sizes": {1, 2}}'`

And observed the following logs:

```
INFO 04-18 08:26:11 [backends.py:431] Dynamo bytecode transform time: 5.03 s
INFO 04-18 08:26:15 [backends.py:120] Directly load the compiled graph(s) for shape None from the cache, took 4.190 s
INFO 04-18 08:26:18 [kv_cache_utils.py:634] GPU KV cache size: 532,032 tokens
```
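
For reference, the same compilation options can be set through the Python API. This sketch assumes the `-O` JSON maps onto the `compilation_config` accepted by the `LLM` constructor and that `compile_sizes` takes a list there; treat the field names as an assumption, not a confirmed API:

```python
from vllm import LLM

# Assumption: the -O '{"level": 3, "compile_sizes": {1, 2}}' CLI JSON
# corresponds to the compilation_config dict accepted by LLM().
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    compilation_config={"level": 3, "compile_sizes": [1, 2]},
)
```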

Side note: it's probably not good that loading from the cache takes 4 seconds?
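
If anyone wants to dig into where those 4 seconds go, one quick way is to profile the load path with the standard library. A sketch, where `load_compiled_graph` is a hypothetical stand-in for the real entry point:

```python
import cProfile
import pstats

def load_compiled_graph():
    # Hypothetical stand-in for the function that reads compiled
    # graphs from the vLLM disk cache; substitute the real call.
    ...

profiler = cProfile.Profile()
profiler.enable()
load_compiled_graph()
profiler.disable()

# Show the 20 functions with the most cumulative time, to see whether
# the cost is dominated by deserialization, disk I/O, or something else.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)
```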

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@zou3519 zou3519 marked this pull request as ready for review April 18, 2025 15:43

@youkaichao youkaichao left a comment


The code change looks good to me. I didn't expect it to take so long, though.

@zou3519 zou3519 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 19, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) April 19, 2025 14:45
@tlrmchlsmth tlrmchlsmth merged commit 682e0b6 into vllm-project:main Apr 19, 2025
44 checks passed
@github-project-automation github-project-automation bot moved this from To triage to Done in torch.compile integration Apr 19, 2025
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
