# How to debug the vLLM-torch.compile integration

TL;DR:

- Use tlparse to acquire torch.compile logs. Include these logs in bug reports and/or support asks.
- The vLLM-torch.compile integration has multiple pieces. vLLM exposes flags to turn off each piece:

| Online Flag | Offline Flag | Result |
|----------|----------|-------------|
| --enforce-eager | enforce_eager=True | Turn off torch.compile and CUDAGraphs |
| -O.mode=0 | compilation_config=CompilationConfig(mode=CompilationMode.NONE) | Turn off torch.compile only |
| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | Turn off CUDAGraphs only |
| -O.backend=eager | compilation_config=CompilationConfig(backend='eager') | Turn off TorchInductor |

## vLLM-torch.compile overview

To improve performance, vLLM leverages torch.compile and CUDAGraphs:
torch.compile generates optimized kernels for PyTorch code, while CUDAGraphs eliminates CPU launch overhead.
Most notably, vLLM-compile is NOT torch.compile; it is a custom compiler built using internal PyTorch Compile APIs.

The integration works in four steps:

- Given a model, we do a full graph capture via TorchDynamo that is dynamic on the batch size (number of tokens).
- vLLM then optionally splits and/or specializes this graph, then uses TorchInductor to compile each resulting graph into a compiled artifact.
This step may use vLLM custom Inductor passes to further optimize the graph.
- The compiled artifact is saved to vLLM's compile cache so that it can be loaded in the future.
- vLLM applies CUDAGraphs to reduce CPU overheads.

Things can go wrong in each of the four steps. When something does go wrong, please try to isolate the subsystem
that went wrong: this lets you turn off the minimal number of features (keeping your reliability
goals while minimizing the performance impact) and also helps us (vLLM) when you open a bug report.

For more details on the design, please see the following resources:

- [Introduction to vLLM-torch.compile blogpost](https://blog.vllm.ai/2025/08/20/torch-compile.html)
- [vLLM-torch.compile integration design](https://docs.vllm.ai/en/latest/design/torch_compile.html)
- [vLLM Office Hours #26](https://www.youtube.com/live/xLyxc7hxCJc?si=Xulo9pe53C6ywf0V&t=561)
- [Talk at PyTorch Conference 2025](https://youtu.be/1wV1ESbGrVQ?si=s1GqymUfwiwOrDTg&t=725)

## Use tlparse

Use [tlparse](https://github.com/meta-pytorch/tlparse) to acquire torch.compile logs. These logs show all stages of the compilation process,
including the fused kernels that torch.compile produces.
If you can, we recommend attaching these logs (or the relevant pieces of them) to any bug report --
they are very helpful.

Install tlparse:

```sh
pip install tlparse
```

Usage (offline inference):

```sh
TORCH_TRACE=~/trace_dir python my_script.py
tlparse ~/trace_dir/<the_first_log_file>
```
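
For reference, here is a minimal sketch of what `my_script.py` might contain. The model name is an arbitrary example, not a recommendation:

```py
# my_script.py -- a minimal offline inference script to trace.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # example model; substitute your own
params = SamplingParams(max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```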

Usage (serving):

```sh
TORCH_TRACE=~/trace_dir vllm serve
# ctrl-c out of the server
tlparse ~/trace_dir/<the_first_log_file>
```

The `tlparse` command outputs some HTML files (e.g. into `./tl_out/index.html`).
Open the index page to see the logs. It'll look something like the following:

## Turn off the vLLM-torch.compile integration

Pass `--enforce-eager` to turn off the vLLM-torch.compile integration and run entirely
in eager mode. This includes turning off CUDAGraphs.

```sh
# Online
vllm serve --enforce-eager
```

```py
# Offline
LLM(model, enforce_eager=True)
```

To turn off just torch.compile, pass `mode = NONE` to the compilation config
(`-O` is short for `--compilation_config`):

```sh
# Online
vllm serve -O.mode=0
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CompilationMode
LLM(model, compilation_config=CompilationConfig(mode=CompilationMode.NONE))
```

To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:

```sh
# Online
vllm serve -O.cudagraph_mode=NONE
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CUDAGraphMode
LLM(model, compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE))
```

## Debugging TorchDynamo

vLLM requires that model code be capturable into a full graph via TorchDynamo (torch.compile's frontend).
TorchDynamo does not support all of Python. It will error out (in fullgraph mode) if it cannot support
a feature (this is sometimes known as a graph break).

If you encounter a graph break, please [open an issue against pytorch/pytorch](https://github.com/pytorch/pytorch) so the PyTorch devs can prioritize it.
Then, try your best to rewrite the code to avoid the graph break.
For more information, see this [Dynamo guide](https://docs.pytorch.org/docs/stable/compile/programming_model.dynamo_core_concepts.html).
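
As a quick standalone check (outside of vLLM), you can also ask TorchDynamo where it breaks the graph. The snippet below is a sketch; the `.item()` call is just one example of a construct that commonly triggers a graph break:

```py
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # Branching on a value read back to the CPU commonly causes a graph break.
    if x.sum().item() > 0:
        return x * 2
    return x - 1

# torch._dynamo.explain traces the function and reports any graph breaks
# without requiring the full-graph capture to succeed.
report = torch._dynamo.explain(forward)(torch.randn(8))
print(report.graph_break_count)  # number of graph breaks encountered
print(report.break_reasons)      # why each break happened
```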

## Debugging Dynamic Shapes full graph capture

vLLM requires that the model's forward pass be capturable into a full graph that is dynamic
on the batch size (i.e. the number of tokens). By default, it compiles this one graph into
one artifact and uses this artifact for all batch sizes.

If your code cannot be captured with dynamic shapes, you may see silent incorrectness,
loud errors, or CUDA illegal memory accesses. For example, the following is not
capturable into a single graph:

```py
if data.shape[0] % 128 == 0:
    foo(...)
else:
    bar(...)
```

This problem is easy to diagnose. Use tlparse and click on `compilation_metrics`:
it will show you the symbolic constraints on the batch size. If there is any constraint
that restricts the batch sizes, then we've got a problem.

To avoid this, please either:

1. avoid branching on the number of tokens, or
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators (see the sketch below).
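
Here is a sketch of option 2, assuming a recent PyTorch with `torch.library.custom_op`; the op name `mylib::token_dependent_op` and the `foo`/`bar` kernels are placeholders:

```py
import torch

def foo(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # placeholder fast path

def bar(x: torch.Tensor) -> torch.Tensor:
    return x + 1  # placeholder fallback path

@torch.library.custom_op("mylib::token_dependent_op", mutates_args=())
def token_dependent_op(x: torch.Tensor) -> torch.Tensor:
    # The branch on the number of tokens now lives inside the custom op,
    # so TorchDynamo treats the op as a single opaque call.
    if x.shape[0] % 128 == 0:
        return foo(x)
    return bar(x)

@token_dependent_op.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation used during tracing: shapes/dtypes only,
    # no data-dependent branching.
    return torch.empty_like(x)
```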

## Debugging TorchInductor

TorchInductor takes a captured graph and compiles it down to Python code
that may call one or more Triton kernels. On rare (but unfortunate) occasions, it may
produce an incorrect Triton kernel. This may manifest as silent incorrectness,
CUDA illegal memory accesses, or loud errors.

To check if TorchInductor is at fault, you can disable it by passing `backend='eager'`
to the compilation config:

```sh
# online
vllm serve -O.backend=eager
```

```py
# offline
from vllm.config.compilation import CompilationConfig
LLM(model, compilation_config=CompilationConfig(backend='eager'))
```

If Inductor is at fault, [file a bug against PyTorch](https://github.com/pytorch/pytorch).
If you're feeling adventurous, you can debug the Triton kernels in the Inductor output code
(which you can locate via tlparse).

You can also use `TORCH_LOGS=output_code <command>` to print the Inductor output code.

### Editable TorchInductor code

You can edit the TorchInductor code that gets run by setting `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`
or passing `-O.compile_cache_save_format=unpacked`. The default is `binary`, which is not editable.

This is a useful technique: you can put breakpoints (e.g. `torch.distributed.breakpoint()`)
and print statements in the output code.
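
For example, in an offline script you might set the environment variable before constructing the engine. This is a sketch; it assumes the variable is read when vLLM starts up, and the model name is a placeholder:

```py
import os

# Must be set before vLLM compiles the model.
os.environ["VLLM_COMPILE_CACHE_SAVE_FORMAT"] = "unpacked"

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
# The compile cache (check the logs for its location) should now contain
# editable output code that you can add breakpoints and prints to.
```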

## Debugging vLLM-compile cache

vLLM built its own cache for torch.compile artifacts. The idea is that an artifact
can be compiled once and then reused on subsequent runs. This
is a layer on top of [torch.compile's compiler cache](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html).

While torch.compile's compiler cache is rock-stable, vLLM's compile cache is unfortunately
not always correct. You can disable it by setting `VLLM_DISABLE_COMPILE_CACHE=1`.

You can also manually remove the caches:

- Remove vLLM's compile cache with `rm -rf ~/.cache/vllm` (check the logs to see if the location changed)
- Remove torch.compile's built-in caches with `rm -rf /tmp/torchinductor_$(whoami)`

vLLM's cache is a mapping from cache key to compiled artifact. vLLM computes
the cache key by combining multiple factors (e.g. config flags and the model name).
If vLLM's compile cache is wrong, this usually means that a factor is missing.
Please see [this example](https://github.com/vllm-project/vllm/blob/18b39828d90413d05d770dfd2e2f48304f4ca0eb/vllm/config/model.py#L310)
of how vLLM computes part of the cache key.
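
To illustrate the idea, here is a simplified sketch (not vLLM's actual implementation): a cache key is a hash over every factor that should invalidate the compiled artifact.

```py
import hashlib

def compute_cache_key(model_name: str, config_flags: dict) -> str:
    # Every factor that affects the compiled artifact must be hashed in.
    # If a relevant factor is missing, two different setups can collide on
    # the same key and silently load the wrong artifact.
    factors = [model_name, sorted(config_flags.items())]
    return hashlib.sha256(repr(factors).encode()).hexdigest()

print(compute_cache_key("example/model", {"backend": "inductor", "mode": 3}))
```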

## Debugging CUDAGraphs

CUDAGraphs is a feature that allows one to:

- Capture a callable that launches one or more CUDA kernels into a CUDAGraph
- Replay the CUDAGraph

The captured CUDAGraph holds on to all of the memory used during the capture process.
A replay of the CUDAGraph reads from and writes to exactly the same regions of memory.

This leads to some restrictions:

1. In order to use CUDAGraphs on new data, you'll need to copy the data into a buffer
that the CUDAGraph reads from.
2. CUDAGraphs only capture CUDA kernels; they don't capture work done on the CPU.

vLLM uses the raw CUDAGraphs API, which is unsafe when used incorrectly.
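
For intuition, here is a standalone sketch of the raw PyTorch CUDAGraph workflow (this is not vLLM code and requires a GPU); note how new inputs must be copied into the buffer that was captured:

```py
import torch

static_input = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture (per the PyTorch CUDA Graphs recipe).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = static_input * 2 + 1
torch.cuda.current_stream().wait_stream(s)

# Capture: the graph records kernels operating on these exact buffers.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2 + 1

# Replay on new data: restriction 1 -- copy into the captured input buffer.
static_input.copy_(torch.randn(8, device="cuda"))
g.replay()
print(static_output)  # reflects the new data; same memory, same recorded kernels
```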

To turn off just CUDAGraphs, pass `cudagraph_mode = NONE`:

```sh
# Online
vllm serve -O.cudagraph_mode=NONE
```

```py
# Offline
from vllm.config.compilation import CompilationConfig, CUDAGraphMode
LLM(model, compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE))
```