Skip to content

Commit 6dfdf90

Browse files
short refactor and addressing comments
- Concatenated the traced-file list lazily for debug logging. - Refactored comments. - Updated ignored factors. - Logged errors when the metadata write fails. - Persists metadata only if the file doesn't already exist. Srreyansh Sethi <srreyansh.sethi@gmail.com> Co-Authored-By: vnadathur <236933696+vnadathur@users.noreply.github.com> Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
1 parent ef0273a commit 6dfdf90

File tree

3 files changed

+66
-27
lines changed

3 files changed

+66
-27
lines changed

tests/config/test_config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _expected_path(p_str: str = ".") -> str:
2222
import pathlib
2323

2424
p = pathlib.Path(p_str)
25-
return str(p.expanduser().resolve())
25+
return p.expanduser().resolve().as_posix()
2626

2727

2828
# Minimal dataclass to test get_hash_factors.

vllm/compilation/backends.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import ast
5+
import logging
56
import dataclasses
67
import hashlib
78
import json
@@ -531,9 +532,17 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
531532
config_hash = vllm_config.compute_hash()
532533
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
533534
forward_code_files = list(sorted(self.compilation_config.traced_files))
535+
class _LazyJoin:
536+
def __init__(self, seq: list[str], sep: str = "\n"):
537+
self.seq = seq
538+
self.sep = sep
539+
540+
def __str__(self) -> str:
541+
return self.sep.join(self.seq)
542+
534543
logger.debug(
535544
"Traced files (to be considered for compilation cache):\n%s",
536-
"\n".join(forward_code_files),
545+
_LazyJoin(forward_code_files),
537546
)
538547
hash_content = []
539548
for filepath in forward_code_files:
@@ -558,7 +567,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
558567
# graph.
559568
factors = [env_hash, config_hash, code_hash, compiler_hash]
560569
# Use SHA-256 for cache key hashing to be consistent across
561-
# compute_hash functions. Truncate for a short, stable dir name.
570+
# compute_hash functions. Truncate for a short cache dir name.
562571
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
563572
cache_dir = os.path.join(
564573
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
@@ -600,27 +609,36 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
600609

601610
# Persist and log only hash-relevant factors together.
602611
try:
603-
logger.debug(
604-
"Compile env factors (raw):\n%s\nVllm config hash: %s",
605-
pprint.pformat(env_factors, width=120),
606-
config_hash,
607-
)
608-
meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
609-
with open(meta_path, "w") as f:
610-
json.dump(
611-
{
612-
"env": env_factors, # raw factors used for env_hash
613-
"config_hash": config_hash,
614-
"code_hash": code_hash,
615-
"compiler_hash": compiler_hash,
616-
},
617-
f,
618-
indent=2,
619-
sort_keys=True,
612+
if logger.isEnabledFor(logging.DEBUG):
613+
logger.debug(
614+
"Compile env factors (raw):\n%s\nVllm config hash: %s",
615+
pprint.pformat(env_factors, width=120),
616+
config_hash,
620617
)
618+
meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
619+
if not os.path.exists(meta_path):
620+
with open(meta_path, "w") as f:
621+
json.dump(
622+
{
623+
"env": env_factors, # raw factors used for env_hash
624+
"config_hash": config_hash,
625+
"code_hash": code_hash,
626+
"compiler_hash": compiler_hash,
627+
},
628+
f,
629+
indent=2,
630+
sort_keys=True,
631+
)
621632
except Exception:
622633
# Best-effort only; metadata write failures are non-fatal.
623-
pass
634+
logger.warning(
635+
(
636+
"Could not write compile cache metadata at %s; continuing without "
637+
"metadata. Compiled cache remains valid; diagnostics may be limited."
638+
),
639+
local_cache_dir,
640+
exc_info=True,
641+
)
624642

625643
# when dynamo calls the backend, it means the bytecode
626644
# transform and analysis are done
@@ -727,4 +745,4 @@ def copy_and_call(*args):
727745
list_args[index] = static_tensor
728746
return self.split_gm(*list_args)
729747

730-
return copy_and_call
748+
return copy_and_call

vllm/envs.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,24 +1396,45 @@ def set_vllm_use_v1(use_v1: bool):
13961396

13971397
def compile_factors() -> dict[str, object]:
13981398
"""
1399-
Return raw env factors for compile hashing using the legacy opt-out
1400-
strategy: include all known env vars except a minimal set that clearly
1401-
does not affect compiled graph structure or kernel routing.
1399+
Return environment variables used to compute the compile cache key.
1400+
This includes all known vLLM environment variables.
1401+
This then excludes variables that cannot affect graph structure, codegen, or kernel
1402+
selection (see ignored_factors).
14021403
"""
14031404

14041405
ignored_factors: set[str] = {
14051406
"MAX_JOBS",
14061407
"VLLM_RPC_BASE_PATH",
14071408
"VLLM_USE_MODELSCOPE",
14081409
"VLLM_RINGBUFFER_WARNING_INTERVAL",
1410+
"VLLM_DEBUG_DUMP_PATH",
1411+
"VLLM_PORT",
1412+
"VLLM_CACHE_ROOT",
14091413
"LD_LIBRARY_PATH",
1410-
"VLLM_PATTERN_MATCH_DEBUG",
14111414
"VLLM_SERVER_DEV_MODE",
14121415
"VLLM_DP_MASTER_IP",
14131416
"VLLM_DP_MASTER_PORT",
14141417
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS",
14151418
"VLLM_CI_USE_S3",
14161419
"VLLM_MODEL_REDIRECT_PATH",
1420+
"VLLM_HOST_IP",
1421+
"S3_ACCESS_KEY_ID", "S3_SECRET_ACCESS_KEY", "S3_ENDPOINT_URL",
1422+
"VLLM_USAGE_STATS_SERVER", "VLLM_NO_USAGE_STATS", "VLLM_DO_NOT_TRACK",
1423+
"VLLM_LOGGING_LEVEL", "VLLM_LOGGING_PREFIX",
1424+
"VLLM_LOGGING_STREAM", "VLLM_LOGGING_CONFIG_PATH",
1425+
"VLLM_LOG_STATS_INTERVAL",
1426+
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE",
1427+
"VLLM_TUNED_CONFIG_FOLDER",
1428+
"VLLM_ENGINE_ITERATION_TIMEOUT_S",
1429+
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
1430+
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
1431+
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
1432+
"VLLM_SLEEP_WHEN_IDLE",
1433+
"VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT",
1434+
"VLLM_AUDIO_FETCH_TIMEOUT", "VLLM_MEDIA_URL_ALLOW_REDIRECTS",
1435+
"VLLM_MEDIA_LOADING_THREAD_COUNT",
1436+
"VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",
1437+
"VLLM_VIDEO_LOADER_BACKEND",
14171438
}
14181439

14191440
from vllm.config.utils import normalize_value
@@ -1427,4 +1448,4 @@ def compile_factors() -> dict[str, object]:
14271448

14281449
factors[factor] = normalize_value(raw)
14291450

1430-
return factors
1451+
return factors

0 commit comments

Comments
 (0)