|
15 | 15 |
|
16 | 16 | import vllm.envs as envs |
17 | 17 | from vllm.compilation.counter import compilation_counter |
18 | | -from vllm.config import VllmConfig |
| 18 | +from vllm.config import VllmConfig, get_current_vllm_config |
19 | 19 | from vllm.utils.torch_utils import is_torch_equal_or_newer |
20 | 20 |
|
21 | 21 |
|
@@ -163,6 +163,19 @@ def get_inductor_factors() -> list[Any]: |
163 | 163 | return factors |
164 | 164 |
|
165 | 165 |
|
| 166 | +def is_compile_cache_enabled() -> bool: |
| 167 | + # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches |
| 168 | + # with torch.compiler.config.force_disable_caches when minimum PyTorch |
| 169 | + # version reaches 2.10 |
| 170 | + return ( |
| 171 | + not envs.VLLM_DISABLE_COMPILE_CACHE |
| 172 | + and not torch._inductor.config.force_disable_caches |
| 173 | + and not ( |
| 174 | + get_current_vllm_config().compilation_config.inductor_compile_config.force_disable_caches |
| 175 | + ) |
| 176 | + ) |
| 177 | + |
| 178 | + |
166 | 179 | class InductorStandaloneAdaptor(CompilerInterface): |
167 | 180 | """ |
168 | 181 | The adaptor for the Inductor compiler. |
@@ -219,7 +232,8 @@ def compile( |
219 | 232 | # Save the compiled artifact to disk in the specified path |
220 | 233 | assert key is not None |
221 | 234 | path = os.path.join(self.cache_dir, key) |
222 | | - if not envs.VLLM_DISABLE_COMPILE_CACHE: |
| 235 | + |
| 236 | + if is_compile_cache_enabled(): |
223 | 237 | compiled_graph.save(path=path, format="unpacked") |
224 | 238 | compilation_counter.num_compiled_artifacts_saved += 1 |
225 | 239 | return compiled_graph, (key, path) |
@@ -469,10 +483,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv: |
469 | 483 | config_patches=current_config, |
470 | 484 | ) |
471 | 485 |
|
472 | | - # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch |
473 | | - # compilation cache. So turn off the checks if we disable the |
474 | | - # compilation cache. |
475 | | - if not envs.VLLM_DISABLE_COMPILE_CACHE: |
| 486 | + # Turn off the checks if we disable the compilation cache. |
| 487 | + if is_compile_cache_enabled(): |
476 | 488 | if hash_str is None: |
477 | 489 | raise RuntimeError( |
478 | 490 | "vLLM failed to compile the model. The most " |
|
0 commit comments