Commit 9a40f07

[Docs] Add guide to debugging vLLM-torch.compile integration
The main takeaways from this page are: 1) use tlparse to get logs (and/or send them around, because they help diagnose issues); 2) there are a variety of flags that can be used to turn off specific features -- use these to isolate the exact feature that is problematic.

Signed-off-by: Richard Zou <zou3519@gmail.com>
1 parent 878fd5a commit 9a40f07

File tree: 4 files changed, +218 -0 lines changed (3 image assets added: 315 KB, 359 KB, 257 KB)

docs/design/debug_vllm_compile.md

Lines changed: 218 additions & 0 deletions
# How to debug the vLLM-torch.compile integration

TL;DR:

- Use tlparse to acquire torch.compile logs. Include these logs in bug reports and/or support asks.
- The vLLM-torch.compile integration consists of multiple pieces. vLLM exposes flags to turn off each piece.

## vLLM-torch.compile overview

To improve performance, vLLM leverages torch.compile and CUDAGraphs to speed things up.
torch.compile generates optimized kernels for PyTorch code, while CUDAGraphs eliminates CPU launch overhead.
Most notably, vLLM-compile is NOT torch.compile; it is a custom compiler built using internal PyTorch Compile APIs.

![vLLM-compile diagram](../assets/design/debug_vllm_compile/design_diagram.png)

- Given a model, we do a full graph capture via TorchDynamo that is dynamic on the batch size (number of tokens).
- vLLM then optionally splits and/or specializes this graph, and uses TorchInductor to compile each resulting graph into a compiled artifact.
  This step may use vLLM custom Inductor passes to further optimize the graph.
- The compiled artifact is saved to vLLM's compile cache so that it can be loaded in the future.
- vLLM applies CUDAGraphs to reduce CPU overheads.

Things can go wrong in each of these four steps. When something does go wrong, please try to isolate the subsystem
that went wrong -- this lets you turn off the minimal number of features (keeping your reliability
goals while minimizing the impact to performance) and also helps us (vLLM) when you open a bug report.

For more details on the design, please see the following talks:

- [PyTorch Conference 2025](https://youtu.be/1wV1ESbGrVQ?si=s1GqymUfwiwOrDTg&t=725)
- [vLLM Office Hours #26](https://www.youtube.com/live/xLyxc7hxCJc?si=Xulo9pe53C6ywf0V&t=561)

## Use tlparse

Use tlparse to acquire torch.compile logs. These logs show all stages of the compilation process,
including the fused kernels that torch.compile produces.
If you can, we recommend sending these logs (or pieces of them) along with any bug reports --
they are very helpful.

Install tlparse:

```sh
pip install tlparse
```

Usage (offline inference):

```sh
TORCH_TRACE=~/trace_dir python my_script.py
tlparse ~/trace_dir/<the_first_log_file>
```

Usage (serving):

```sh
TORCH_TRACE=~/trace_dir vllm serve
# ctrl-c out of the server
tlparse ~/trace_dir/<the_first_log_file>
```

The `tlparse` command outputs some HTML files (e.g. into `./tl_out/index.html`).
Open `index.html` to see the logs. It'll look something like the following:

![tlparse example](../assets/design/debug_vllm_compile/tlparse_inductor.png)

## Turn off vLLM-torch.compile integration

Pass `--enforce-eager` to turn off the vLLM-torch.compile integration and run entirely
in eager mode. This includes turning off CUDAGraphs.

```sh
# Online
vllm serve --enforce-eager
```

```py
# Offline
LLM(model, enforce_eager=True)
```

To turn off just torch.compile, pass `mode = NONE` to the compilation config
(`-O` is short for `--compilation_config`):

```sh
# Online
vllm serve -O.mode=0
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CompilationMode
LLM(model, compilation_config=CompilationConfig(mode=CompilationMode.NONE))
```

To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:

```sh
# Online
vllm serve -O.cudagraph_mode=NONE
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CUDAGraphMode
LLM(model, compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE))
```

## Debugging TorchDynamo

vLLM requires that model code be capturable into a full graph via TorchDynamo (torch.compile's frontend).
TorchDynamo does not support all of Python. It will error (in fullgraph mode) if it encounters
a feature it cannot support (this is sometimes known as a graph break).

If you encounter a graph break, please [open an issue against pytorch/pytorch](https://github.com/pytorch/pytorch) so the PyTorch devs can prioritize a fix.
Then, try your best to rewrite the code to avoid the graph break.
For more information, see this [Dynamo guide](https://docs.pytorch.org/docs/stable/compile/programming_model.dynamo_core_concepts.html).

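If you want to reproduce a suspected graph break outside of vLLM, a minimal standalone sketch like the following can help (this is plain PyTorch, not vLLM code; the function below is a made-up example of data-dependent control flow, which Dynamo cannot capture):

```py
import torch

def forward_like(x: torch.Tensor) -> torch.Tensor:
    # Branching on a tensor's value is data-dependent control flow:
    # Dynamo cannot capture it into one graph and errors in fullgraph mode.
    if x.sum() > 0:
        return x * 2
    return x - 1

compiled = torch.compile(forward_like, fullgraph=True)
try:
    compiled(torch.randn(8))
except Exception as err:
    # The error message points at the offending line of Python code.
    print(f"{type(err).__name__}: {err}")
```
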
## Debugging Dynamic Shape full graph capture

vLLM requires that the model's forward pass be capturable into a full graph that is dynamic
on the batch size (i.e. the number of tokens). It (by default) compiles this one graph into
one artifact and uses this artifact for all batch sizes.

If your code cannot be captured with dynamic shapes, you may see silent incorrectness,
loud errors, or CUDA illegal memory accesses. For example, the following is not
capturable into a single graph that is dynamic on the batch size:

```py
if data.size(0) % 128 == 0:
    foo(...)
else:
    bar(...)
```

This problem is easy to diagnose. Use tlparse and click on `compilation_metrics`:
it will show the symbolic constraints placed on the batch size. If any constraint
restricts the possible batch sizes, then we've got a problem.

![Bad tlparse example](../assets/design/debug_vllm_compile/dynamic_shapes.png)

To avoid this, please either:

1. avoid branching on the number of tokens, or
2. wrap the branching logic into a custom operator (see the sketch below). TorchDynamo does not
   trace into custom operators.

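Here is a minimal sketch of option 2 using `torch.library.custom_op` (available in recent PyTorch releases). The operator name and the padding logic are made-up examples rather than vLLM code; the point is that Dynamo treats the custom op as opaque, so the shape-dependent branch inside it does not constrain the captured graph:

```py
import torch

@torch.library.custom_op("mylib::padded_matmul", mutates_args=())
def padded_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Branch on the number of tokens inside the op, hidden from Dynamo.
    if x.shape[0] % 128 == 0:
        return x @ w  # fast path
    pad = 128 - x.shape[0] % 128
    x_padded = torch.nn.functional.pad(x, (0, 0, 0, pad))
    return (x_padded @ w)[: x.shape[0]]

@padded_matmul.register_fake
def _(x, w):
    # Shape/dtype function used during tracing; no real computation happens here.
    return x.new_empty((x.shape[0], w.shape[1]))
```

The `register_fake` function tells the compiler what the output looks like (shape, dtype, device) without running the real implementation.
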
## Debugging TorchInductor

TorchInductor takes a captured graph and compiles it down to Python code
that may call 1+ Triton kernels. On rare (but unfortunate) occasions, it may
produce an incorrect Triton kernel. This may manifest as silent incorrectness,
CUDA illegal memory accesses, or loud errors.

To check whether TorchInductor is at fault, you can disable it by passing `use_inductor=False`
to the compilation config:

```sh
# online
vllm serve -O.use_inductor=False
```

```py
# offline
from vllm.config.compilation import CompilationConfig
LLM(model, compilation_config=CompilationConfig(use_inductor=False))
```

If Inductor is at fault, [file a bug against PyTorch](https://github.com/pytorch/pytorch).
If you're feeling adventurous, you can debug the Triton kernels in the Inductor output code
(which you can locate using tlparse).

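If you'd rather not dig through tlparse output, you can also ask PyTorch to dump the Inductor-generated code directly (this is standard PyTorch logging, not a vLLM flag, and assumes a reasonably recent PyTorch):

```sh
# Prints the Python/Triton code that Inductor generated for each compiled graph.
TORCH_LOGS=output_code python my_script.py
```
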
## Debugging vLLM-compile cache

vLLM built its own cache for torch.compile artifacts. The idea is that an artifact
can be compiled once and then reused on subsequent runs. This cache
is a layer on top of [torch.compile's compiler cache](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html).

While torch.compile's compiler cache is rock-stable, vLLM's compiler cache is unfortunately
not always correct. You can disable it by setting `VLLM_DISABLE_COMPILE_CACHE=1`.

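For example (serving shown; the same environment variable applies to offline inference):

```sh
# Skip vLLM's compile cache entirely for this run.
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve
```
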
You can also manually remove this cache:

- Remove vLLM's compile cache with `rm -rf ~/.cache/vllm` (look at the logs to see if the location changed)
- Remove torch.compile's built-in caches with `rm -rf /tmp/torchinductor_$(whoami)`

vLLM's cache is a mapping from cache key to compiled artifact. vLLM computes
the cache key by combining multiple factors (e.g. config flags and the model name).
If vLLM's compile cache is wrong, this usually means that a factor is missing.
Please see [this example](https://github.com/vllm-project/vllm/blob/18b39828d90413d05d770dfd2e2f48304f4ca0eb/vllm/config/model.py#L310)
of how vLLM computes part of the cache key.

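To make that failure mode concrete, here is an illustrative-only sketch (not vLLM's actual implementation): the cache key is a hash over factors that affect compilation, so leaving out a factor that changes the compiled code lets two different configurations collide on the same cache entry.

```py
import hashlib

def cache_key(factors: list[str]) -> str:
    # Hash every factor that can change the compiled artifact.
    h = hashlib.sha256()
    for factor in factors:
        h.update(factor.encode())
    return h.hexdigest()[:16]

# Hypothetical factors; vLLM combines its own set (config flags, model name, etc.).
print(cache_key(["model=my-model", "dtype=bfloat16", "tensor_parallel_size=2"]))
```
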
## Debugging CUDAGraphs

CUDAGraphs is a feature that allows one to:

- Capture a callable that launches 1+ CUDA kernels into a CUDAGraph
- Replay the CUDAGraph

The captured CUDAGraph holds on to all of the memory regions used during the capture process.
Replaying the CUDAGraph reads from and writes to exactly the same regions of memory.

This leads to some restrictions:

1. In order to use CUDAGraphs on new data, you'll need to copy the data into a buffer
   that the CUDAGraph is reading from (see the sketch below).
2. CUDAGraphs only capture CUDA kernels; they don't capture work done on the CPU.

vLLM uses the raw CUDAGraphs API, which is unsafe when used incorrectly.

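For intuition, here is a minimal sketch of the raw capture/replay pattern in plain PyTorch (not vLLM code), following the warmup-then-capture recipe from the PyTorch CUDAGraphs documentation. It shows why restriction 1 forces you to copy new inputs into the captured buffer:

```py
import torch

static_in = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, per the PyTorch recipe.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_in * 2 + 1
torch.cuda.current_stream().wait_stream(s)

# Capture: the kernels are recorded into the graph along with the buffers they touch.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2 + 1

# Replay on new data: we must copy into the captured input buffer first.
new_data = torch.randn(8, device="cuda")
static_in.copy_(new_data)
g.replay()  # static_out now holds the results for new_data
```
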
To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:

```sh
# Online
vllm serve -O.cudagraph_mode=NONE
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CUDAGraphMode
LLM(model, compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE))
```
