
Conversation

@yewentao256 (Member) commented on Sep 25, 2025

Purpose

2025-09-24T04:15:29.297966219Z (EngineCore_DP4 pid=1673) INFO 09-24 00:15:29 [v1/worker/gpu_model_runner.py:3380] Graph capturing finished in 79 secs, took -9.96 GiB
2025-09-24T04:15:29.298151199Z (EngineCore_DP4 pid=1673) DEBUG 09-24 00:15:29 [v1/worker/gpu_worker.py:401] Free memory on device (176.17/178.35 GiB) on startup. Desired GPU memory utilization is (0.9, 160.52 GiB). Actual usage is 60.79 GiB for weight, 13.33 GiB for peak activation, 10.96 GiB for non-torch memory, and -9.96 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=91539781222` (85.25 GiB) to fit into requested memory, or `--kv-cache-memory=108345741312` (100.9 GiB) to fully utilize gpu memory. Current kv cache memory in use is 75.44 GiB.

This PR updates the memory-usage capture logic, which should hopefully fix the issue.
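
For reference, the arithmetic behind the suggested `--kv-cache-memory` value in the log can be sketched as below. This is only an illustration using the logged numbers; the variable names are mine, not vLLM's actual accounting code.

```python
# Rough reconstruction of the accounting printed in the log above.
# All values are GiB figures copied from the log; names are illustrative only.
total_gpu_memory = 178.35
gpu_memory_utilization = 0.9
requested = total_gpu_memory * gpu_memory_utilization    # ~160.52 GiB

weights = 60.79
peak_activation = 13.33
non_torch = 10.96
cudagraph = -9.96    # the bogus negative value this PR targets

kv_cache_budget = requested - weights - peak_activation - non_torch - cudagraph
print(f"{kv_cache_budget:.2f} GiB")    # ~85.4 GiB, the ~85.25 GiB suggestion up to rounding
# A negative CUDAGraph term is effectively added back in, so the suggested
# KV-cache budget ends up larger than it would be with a correct positive value.
```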

Test

@smarterclayton Could you try again using this branch?

@gemini-code-assist (Contributor) bot left a comment

Code Review

This pull request addresses a bug where CUDA graph memory usage was reported as a negative value. The root cause was asynchronous CUDA operations leading to inaccurate memory measurements. The fix correctly moves the memory measurement logic inside the graph_capture context and, crucially, adds a torch.cuda.synchronize() call before measuring the final memory usage. This ensures that all graph capture operations are complete before the memory is queried, providing an accurate measurement. The changes are logical, well-targeted, and should resolve the issue. The implementation is correct.
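
A minimal sketch of the measurement pattern described above (the real code in vLLM's gpu_model_runner.py differs in detail; the function below is only an illustration):

```python
import torch

def measure_graph_capture_memory(capture_fn) -> float:
    """Return the device memory (GiB) attributed to capture_fn()."""
    torch.cuda.synchronize()                    # drain pending work first
    free_before, _ = torch.cuda.mem_get_info()

    capture_fn()                                # e.g. CUDA graph capture happens here

    # Without this synchronize, capture work may still be in flight when the
    # second reading is taken, and the delta can come out negative,
    # as with the -9.96 GiB reported in the log above.
    torch.cuda.synchronize()
    free_after, _ = torch.cuda.mem_get_info()

    return (free_before - free_after) / 1024 ** 3
```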

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 force-pushed the wentao-fix-negative-cuda-memory-usage branch from c3c2406 to d771de1 on September 25, 2025 at 15:43
@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Sep 26, 2025
@yewentao256 (Member, Author) commented:

@smarterclayton Hi Clayton, could you validate this fix when you have time, so that we can get it landed?

@mgoin merged commit da554f9 into main on Oct 1, 2025
54 checks passed
@mgoin deleted the wentao-fix-negative-cuda-memory-usage branch on October 1, 2025 at 22:16
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request on Oct 2, 2025
yewentao256 added a commit that referenced this pull request on Oct 3, 2025
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request on Oct 6, 2025
southfreebird pushed a commit to southfreebird/vllm that referenced this pull request on Oct 7, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 10, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request on Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request on Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 24, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request on Nov 10, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed) · v1


3 participants