Skip to content

Commit 60b709e

Browse files
committed
Introduce is_compile_cache_enabled() so the compile cache also honors inductor's force_disable_caches
Summary: Add a single is_compile_cache_enabled() helper that checks VLLM_DISABLE_COMPILE_CACHE, torch._inductor.config.force_disable_caches, and the vLLM compilation config's inductor force_disable_caches setting, and use it everywhere the cache-enabled decision was previously made from the env var alone. Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
1 parent b2e65cb commit 60b709e

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

vllm/compilation/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
EagerAdaptor,
3434
InductorAdaptor,
3535
InductorStandaloneAdaptor,
36+
is_compile_cache_enabled,
3637
)
3738
from .counter import compilation_counter
3839
from .inductor_pass import InductorPass
@@ -238,7 +239,7 @@ def compile(
238239
assert compiled_graph is not None, "Failed to compile the graph"
239240

240241
# store the artifact in the cache
241-
if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
242+
if is_compile_cache_enabled() and handle is not None:
242243
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
243244
compilation_counter.num_cache_entries_updated += 1
244245
self.is_cache_updated = True
@@ -610,7 +611,7 @@ def __call__(
610611
os.makedirs(local_cache_dir, exist_ok=True)
611612
self.compilation_config.local_cache_dir = local_cache_dir
612613

613-
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
614+
disable_cache = not is_compile_cache_enabled()
614615

615616
if disable_cache:
616617
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")

vllm/compilation/compiler_interface.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import vllm.envs as envs
1717
from vllm.compilation.counter import compilation_counter
18-
from vllm.config import VllmConfig
18+
from vllm.config import VllmConfig, get_current_vllm_config
1919
from vllm.utils.torch_utils import is_torch_equal_or_newer
2020

2121

@@ -163,6 +163,19 @@ def get_inductor_factors() -> list[Any]:
163163
return factors
164164

165165

166+
def is_compile_cache_enabled() -> bool:
167+
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
168+
# with torch.compiler.config.force_disable_caches when minimum PyTorch
169+
# version reaches 2.10
170+
return (
171+
not envs.VLLM_DISABLE_COMPILE_CACHE
172+
and not torch._inductor.config.force_disable_caches
173+
and not (
174+
get_current_vllm_config().compilation_config.inductor_compile_config.force_disable_caches
175+
)
176+
)
177+
178+
166179
class InductorStandaloneAdaptor(CompilerInterface):
167180
"""
168181
The adaptor for the Inductor compiler.
@@ -219,7 +232,8 @@ def compile(
219232
# Save the compiled artifact to disk in the specified path
220233
assert key is not None
221234
path = os.path.join(self.cache_dir, key)
222-
if not envs.VLLM_DISABLE_COMPILE_CACHE:
235+
236+
if is_compile_cache_enabled():
223237
compiled_graph.save(path=path, format="unpacked")
224238
compilation_counter.num_compiled_artifacts_saved += 1
225239
return compiled_graph, (key, path)
@@ -469,10 +483,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
469483
config_patches=current_config,
470484
)
471485

472-
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
473-
# compilation cache. So turn off the checks if we disable the
474-
# compilation cache.
475-
if not envs.VLLM_DISABLE_COMPILE_CACHE:
486+
# Turn off the checks if we disable the compilation cache.
487+
if is_compile_cache_enabled():
476488
if hash_str is None:
477489
raise RuntimeError(
478490
"vLLM failed to compile the model. The most "

vllm/compilation/decorators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def patched_inline_call(self_):
447447
InliningInstructionTranslator, "inline_call_", patched_inline_call
448448
),
449449
torch._dynamo.config.patch(**dynamo_config_patches),
450+
set_current_vllm_config(self.vllm_config),
450451
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
451452
_torch27_patch_tensor_subclasses(),
452453
):

0 commit comments

Comments (0)