Skip to content

Commit c0083f3

Browse files
committed
[Bugfix] vLLM should check TorchInductor config for compile cache enablement status
Summary: vLLM should not assume compile cache is enabled when VLLM_DISABLE_COMPILE_CACHE=0. Users may use TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 to disable compile cache at PyTorch level, effectively making it impossible for vLLM compile cache to function. Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
1 parent d34f5fe commit c0083f3

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

vllm/compilation/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
EagerAdaptor,
3434
InductorAdaptor,
3535
InductorStandaloneAdaptor,
36+
is_compile_cache_enabled,
3637
)
3738
from .counter import compilation_counter
3839
from .inductor_pass import InductorPass
@@ -238,7 +239,7 @@ def compile(
238239
assert compiled_graph is not None, "Failed to compile the graph"
239240

240241
# store the artifact in the cache
241-
if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
242+
if is_compile_cache_enabled() and handle is not None:
242243
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
243244
compilation_counter.num_cache_entries_updated += 1
244245
self.is_cache_updated = True
@@ -610,7 +611,7 @@ def __call__(
610611
os.makedirs(local_cache_dir, exist_ok=True)
611612
self.compilation_config.local_cache_dir = local_cache_dir
612613

613-
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
614+
disable_cache = not is_compile_cache_enabled()
614615

615616
if disable_cache:
616617
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")

vllm/compilation/compiler_interface.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ def get_inductor_factors() -> list[Any]:
163163
return factors
164164

165165

166+
def is_compile_cache_enabled() -> bool:
167+
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
168+
# with torch.compiler.config.force_disable_caches when minimum PyTorch
169+
# version reaches 2.9
170+
return (
171+
not envs.VLLM_DISABLE_COMPILE_CACHE
172+
and not torch._inductor.config.force_disable_caches
173+
)
174+
175+
166176
class InductorStandaloneAdaptor(CompilerInterface):
167177
"""
168178
The adaptor for the Inductor compiler.
@@ -219,7 +229,8 @@ def compile(
219229
# Save the compiled artifact to disk in the specified path
220230
assert key is not None
221231
path = os.path.join(self.cache_dir, key)
222-
if not envs.VLLM_DISABLE_COMPILE_CACHE:
232+
233+
if is_compile_cache_enabled():
223234
compiled_graph.save(path=path, format="unpacked")
224235
compilation_counter.num_compiled_artifacts_saved += 1
225236
return compiled_graph, (key, path)
@@ -469,10 +480,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
469480
config_patches=current_config,
470481
)
471482

472-
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
473-
# compilation cache. So turn off the checks if we disable the
474-
# compilation cache.
475-
if not envs.VLLM_DISABLE_COMPILE_CACHE:
483+
# Turn off the checks if we disable the compilation cache.
484+
if is_compile_cache_enabled():
476485
if hash_str is None:
477486
raise RuntimeError(
478487
"vLLM failed to compile the model. The most "

0 commit comments

Comments
 (0)