diff --git a/tests/config/test_config_utils.py b/tests/config/test_config_utils.py new file mode 100644 index 000000000000..9b2081b4d38e --- /dev/null +++ b/tests/config/test_config_utils.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum + +import pytest + +from vllm.config.utils import get_hash_factors, hash_factors, normalize_value + +# Helpers + + +def endswith_fqname(obj, suffix: str) -> bool: + # normalize_value(type) returns fully-qualified name + # Compare suffix to avoid brittle import paths. + out = normalize_value(obj) + return isinstance(out, str) and out.endswith(suffix) + + +def expected_path(p_str: str = ".") -> str: + import pathlib + + p = pathlib.Path(p_str) + return p.expanduser().resolve().as_posix() + + +# Minimal dataclass to test get_hash_factors. +# Avoid importing heavy vLLM configs. +@dataclass +class SimpleConfig: + a: object + b: object | None = None + + +class DummyLogprobsMode(Enum): + RAW_LOGITS = "raw_logits" + + +def test_hash_factors_deterministic(): + """Test that hash_factors produces consistent SHA-256 hashes""" + factors = {"a": 1, "b": "test"} + hash1 = hash_factors(factors) + hash2 = hash_factors(factors) + + assert hash1 == hash2 + assert len(hash1) == 64 + assert all(c in "0123456789abcdef" for c in hash1) + + +@pytest.mark.parametrize( + "inp, expected", + [ + (None, None), + (True, True), + (1, 1), + (1.0, 1.0), + ("x", "x"), + (b"ab", "6162"), + (bytearray(b"ab"), "6162"), + ([1, 2], (1, 2)), + ({"b": 2, "a": 1}, (("a", 1), ("b", 2))), + ], +) +def test_normalize_value_matrix(inp, expected): + """Parametric input→expected normalization table.""" + assert normalize_value(inp) == expected + + +def test_normalize_value_enum(): + # Enums normalize to (module.QualName, value). + # DummyLogprobsMode uses a string payload. + out = normalize_value(DummyLogprobsMode.RAW_LOGITS) + assert isinstance(out, tuple) + assert out[0].endswith("DummyLogprobsMode") + # Expect string payload 'raw_logits'. + assert out[1] == "raw_logits" + + +def test_normalize_value_set_order_insensitive(): + # Sets are unordered; normalize_value sorts elements for determinism. + assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3}) + + +def test_normalize_value_path_normalization(): + from pathlib import Path # local import to avoid global dependency + + # Paths expand/resolve to absolute strings. + # Stabilizes hashing across working dirs. + assert normalize_value(Path(".")) == expected_path(".") + + +def test_normalize_value_uuid_and_to_json(): + # Objects may normalize via uuid() or to_json_string(). + class HasUUID: + def uuid(self): + return "test-uuid" + + class ToJson: + def to_json_string(self): + return '{"x":1}' + + assert normalize_value(HasUUID()) == "test-uuid" + assert normalize_value(ToJson()) == '{"x":1}' + + +@pytest.mark.parametrize( + "bad", + [ + (lambda x: x), + (type("CallableInstance", (), {"__call__": lambda self: 0}))(), + (lambda: (lambda: 0))(), # nested function instance + ], +) +def test_error_cases(bad): + """Inputs expected to raise TypeError.""" + # Reject functions/lambdas/callable instances + # to avoid under-hashing. 
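+    # ("Under-hashing" = omitting state that should influence the hash,
+    # which would let differently-behaving configs share a cache entry.)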
+    with pytest.raises(TypeError):
+        normalize_value(bad)
+
+
+def test_enum_vs_int_disambiguation():
+    # int stays primitive
+    nf_int = normalize_value(1)
+    assert nf_int == 1
+
+    # enum becomes ("module.QualName", value)
+    nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
+    assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
+    enum_type, enum_val = nf_enum
+    assert enum_type.endswith(".DummyLogprobsMode")
+    assert enum_val == "raw_logits"
+
+    # Build factor dicts from configs with int vs enum
+    f_int = get_hash_factors(SimpleConfig(1), set())
+    f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
+    # The int case remains a primitive value
+    assert f_int["a"] == 1
+    # The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
+    assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
+    # Factor dicts must differ so we don't collide primitives with Enums.
+    assert f_int != f_enum
+    # Hash digests must differ correspondingly
+    assert hash_factors(f_int) != hash_factors(f_enum)
+
+    # Hashes are stable 64-char hex strings
+    h_int = hash_factors(f_int)
+    h_enum = hash_factors(f_enum)
+    assert isinstance(h_int, str) and len(h_int) == 64
+    assert isinstance(h_enum, str) and len(h_enum) == 64
+
+
+def test_classes_are_types():
+    """Types normalize to their fully-qualified names."""
+    # Only classes are allowed; functions/lambdas are rejected.
+    # The canonical form is the fully-qualified name.
+    assert isinstance(normalize_value(str), str)
+
+    class LocalDummy:
+        pass
+
+    assert endswith_fqname(LocalDummy, ".LocalDummy")
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index be69075f94f0..91a798ffa628 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -4,11 +4,13 @@
 import ast
 import dataclasses
 import hashlib
+import json
 import os
 import pprint
 import time
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
+from functools import partial
 from typing import Any
 
 import torch
@@ -22,7 +24,9 @@
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.utils import hash_factors
 from vllm.logger import init_logger
+from vllm.logging_utils import lazy
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -567,35 +571,47 @@ def configure_post_pass(self):
     def __call__(
         self, graph: fx.GraphModule, example_inputs
     ) -> VllmSerializableFunction:
-        from .caching import _compute_code_hash, compilation_config_hash_factors
-
         vllm_config = self.vllm_config
+        # Compute every hash factor once with the shared utilities;
+        # the results are reused for the cache dir, logging, and metadata.
+
+        env_factors = envs.compile_factors()
+        env_hash = hash_factors(env_factors)
+        # Config and compiler hashes, computed once and reused below
+        config_hash = vllm_config.compute_hash()
+        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
+        forward_code_files = sorted(self.compilation_config.traced_files)
+
+        logger.debug(
+            "Traced files (to be considered for compilation cache):\n%s",
+            lazy(lambda: "\n".join(forward_code_files)),
+        )
+        hash_content = []
+        for filepath in forward_code_files:
+            hash_content.append(filepath)
+            if filepath == "":
+                # This means the function was dynamically generated, with
+                # e.g. exec(). We can't actually check these.
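+                # Note: the empty name was still appended to hash_content
+                # above, so dynamically generated code is represented in the key.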
+                continue
+            try:
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            except Exception:
+                logger.warning("Failed to read file %s", filepath)
+                continue
+        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
+        # Traced files have been folded into code_hash; clear them.
+        self.compilation_config.traced_files.clear()
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
             # the cache dir will be the same so that we can reuse the compiled
             # graph.
-
-            factors = compilation_config_hash_factors(vllm_config)
-            # 2. factors come from the code files that are traced by Dynamo (
-            # it mainly summarizes how the model is used in forward pass)
-            code_hash = _compute_code_hash(self.compilation_config.traced_files)
-            self.compilation_config.traced_files.clear()
-            factors.append(code_hash)
-
-            # 3. compiler hash
-            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
-            factors.append(compiler_hash)
-
-            # combine all factors to generate the cache dir
-            hash_key = hashlib.md5(
-                str(factors).encode(), usedforsecurity=False
-            ).hexdigest()[:10]
-
+            factors = [env_hash, config_hash, code_hash, compiler_hash]
+            # Use SHA-256 for cache key hashing to be consistent across
+            # compute_hash functions. Truncate for a short cache dir name.
+            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT,
-                "torch_compile_cache",
-                hash_key,
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
             )
             self.compilation_config.cache_dir = cache_dir
@@ -625,6 +641,50 @@ def __call__(
             local_cache_dir, disable_cache, self.prefix
         )
 
+        # Log the per-factor hashes behind the cache key, whether the
+        # cache dir was just generated or reused.
+        logger.debug(
+            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
+            env_hash,
+            config_hash,
+            compiler_hash,
+            code_hash,
+            local_cache_dir,
+        )
+
+        # Persist and log only hash-relevant factors together.
+        try:
+            logger.debug(
+                "Compile env factors (raw):\n%s\nVllm config hash: %s",
+                lazy(partial(pprint.pformat, env_factors, width=120)),
+                config_hash,
+            )
+            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
+            if not os.path.exists(meta_path):
+                with open(meta_path, "w") as f:
+                    json.dump(
+                        {
+                            "env": env_factors,  # raw factors used for env_hash
+                            "config_hash": config_hash,
+                            "code_hash": code_hash,
+                            "compiler_hash": compiler_hash,
+                        },
+                        f,
+                        indent=2,
+                        sort_keys=True,
+                    )
+        except Exception:
+            # Best-effort only; metadata write failures are non-fatal.
+            logger.warning(
+                (
+                    "Could not write compile cache metadata at %s; continuing "
+                    "without metadata. The compiled cache remains valid; "
+                    "diagnostics may be limited."
+                ),
+                local_cache_dir,
+                exc_info=True,
+            )
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 0c2210d72ce0..da9e95dfb3a9 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -127,7 +127,7 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
""" - state = {"pass_config": self.pass_config.uuid(), "passes": []} + state = {"pass_config": self.pass_config.compute_hash(), "passes": []} for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 864cf1be81b2..2652c7c06ad0 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from dataclasses import field from typing import TYPE_CHECKING, Any, Literal @@ -160,13 +159,29 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - factors.append(self.cache_dtype) - factors.append(self.mamba_cache_dtype) - factors.append(self.mamba_ssm_cache_dtype) - # `cpu_offload_gb` does not use `torch.compile` yet. - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + ignored_factors = { + # Runtime/derived knobs that don't affect compiled graph shape + "gpu_memory_utilization", + "swap_space", + "is_attention_free", + "num_gpu_blocks_override", + "enable_prefix_caching", + "prefix_caching_hash_algo", + # `cpu_offload_gb` does not use `torch.compile` yet. + "cpu_offload_gb", + "cpu_kvcache_space_bytes", + "mamba_page_size_padded", + # Post-init/derived counters + "num_gpu_blocks", + "num_cpu_blocks", + # WIP feature toggle not impacting compiled graph shape + "kv_sharing_fast_prefill", + } + + from vllm.config.utils import get_hash_factors, hash_factors + + factors = get_hash_factors(self, ignored_factors) + return hash_factors(factors) def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index e1d60ee84d89..00254a1cfce7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum -import hashlib from collections import Counter from collections.abc import Callable from dataclasses import asdict, field @@ -159,7 +158,7 @@ def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: current_platform.get_device_capability().to_int(), {} ) - def uuid(self): + def compute_hash(self) -> str: """ Produces a hash unique to the pass configuration. Any new fields that affect compilation should be added to the hash. @@ -537,18 +536,27 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - factors.append(self.mode) - factors.append(self.backend) - factors.append(self.custom_ops) - factors.append(self.splitting_ops) - factors.append(self.use_inductor) - factors.append(self.use_inductor_graph_partition) - factors.append(self.inductor_compile_config) - factors.append(self.inductor_passes) - factors.append(self.pass_config.uuid()) - factors.append(self.compile_cache_save_format) - return hashlib.sha256(str(factors).encode()).hexdigest() + # Opt-out: default-include declared fields; keep a tiny exclude set; + # normalize types; keep SHA-256. For nested opaque configs, include a + # stable identifier (e.g., pass_config.compute_hash()) instead of object id. 
+ + ignored_factors = { + # Paths/dirs and runtime/metrics that don’t affect compiled graph + "debug_dump_path", + "cache_dir", + "local_cache_dir", + "bs_to_padded_graph_size", + "traced_files", + "compilation_time", + "static_forward_context", + "pass_config", # handled separately below + } + + from vllm.config.utils import get_hash_factors, hash_factors + + factors = get_hash_factors(self, ignored_factors) + factors["pass_config"] = self.pass_config.compute_hash() + return hash_factors(factors) def __repr__(self) -> str: exclude = { diff --git a/vllm/config/model.py b/vllm/config/model.py index 6ce91ebb87b9..2458c72aa57f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib -import json import warnings from collections.abc import Callable from dataclasses import InitVar, field @@ -18,7 +16,7 @@ from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import assert_hashable, config, getattr_iter +from vllm.config.utils import config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -319,50 +317,50 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - factors: list[Any] = [] - factors.append(self.model) - factors.append(self.dtype) - factors.append(self.quantization) - factors.append(self.revision) - factors.append(self.code_revision) - factors.append(self.max_model_len) - factors.append(self.max_logprobs) - factors.append(self.disable_sliding_window) - factors.append(self.trust_remote_code) - factors.append(self.generation_config) - factors.append(self.model_impl) - factors.append(self.override_generation_config) - factors.append(self.video_pruning_rate) - factors.append(self.enable_prompt_embeds) - - # hf_config can control how the model looks! 
- try: - hf_config_json = self.hf_config.to_json_string(use_diff=False) - except TypeError: - from transformers import PretrainedConfig - - from vllm.utils.jsontree import json_map_leaves - - # Handle nested HF configs with unserializable values gracefully - hf_config_json = ( - json.dumps( - json_map_leaves( - lambda v: v.to_dict() - if isinstance(v, PretrainedConfig) - else str(v), - self.hf_config.to_dict(), - ), - indent=2, - sort_keys=True, - ) - + "\n" - ) - - factors.append(hf_config_json) - - str_factors = str(factors) - assert_hashable(str_factors) - return hashlib.sha256(str(factors).encode()).hexdigest() + ignored_factors = { + "runner", + "convert", + "task", + "tokenizer", + "tokenizer_mode", + "seed", + "hf_config_path", + "allowed_local_media_path", + "allowed_media_domains", + "tokenizer_revision", + "spec_target_max_model_len", + "enforce_eager", + "logprobs_mode", + "disable_cascade_attn", + "skip_tokenizer_init", + "enable_prompt_embeds", + "served_model_name", + "config_format", + "hf_token", + "hf_overrides", + "logits_processor_pattern", + "enable_sleep_mode", + "override_attention_dtype", + "logits_processors", + "io_processor_plugin", + "pooler_config", + "override_pooler_config", + "multimodal_config", + "limit_mm_per_prompt", + "media_io_kwargs", + "mm_processor_kwargs", + "mm_processor_cache_gb", + "mm_processor_cache_type", + "mm_shm_cache_max_object_size_mb", + "mm_encoder_tp_mode", + "interleave_mm_strings", + "skip_mm_profiling", + } + + from vllm.config.utils import get_hash_factors, hash_factors + + factors = get_hash_factors(self, ignored_factors) + return hash_factors(factors) def _update_nested( self, diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index b19c8beeae3d..8c7ef8089e85 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib import os from typing import TYPE_CHECKING, Any, Literal @@ -419,19 +418,33 @@ def compute_hash(self): This hash is also used for DP worker configuration validation to prevent hangs from mismatched collective communication patterns. 
""" - factors: list[Any] = [] - factors.append(self.pipeline_parallel_size) - factors.append(self.tensor_parallel_size) - factors.append(self.enable_expert_parallel) - factors.append(self.data_parallel_size) - factors.append(self.all2all_backend) - factors.append(self.enable_eplb) - if self.enable_eplb: - factors.append(self.eplb_config.log_balancedness) - factors.append(self.eplb_config.window_size) - factors.append(self.eplb_config.step_interval) - factors.append(self.eplb_config.num_redundant_experts) - return hashlib.sha256(str(factors).encode()).hexdigest() + ignored_factors = { + # Derived/runtime topology, networking, or launch details + "data_parallel_rank", + "data_parallel_rank_local", + "data_parallel_master_ip", + "data_parallel_rpc_port", + "data_parallel_master_port", + "data_parallel_backend", + "data_parallel_external_lb", + "data_parallel_hybrid_lb", + "max_parallel_loading_workers", + "disable_custom_all_reduce", + "ray_workers_use_nsight", + "ray_runtime_env", + "placement_group", + "distributed_executor_backend", + "worker_cls", + "sd_worker_cls", + "worker_extension_cls", + } + + from vllm.config.utils import get_hash_factors, hash_factors + + factors = get_hash_factors(self, ignored_factors) + # Explicitly include backend affecting env factor as before + factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND) + return hash_factors(factors) def __post_init__(self) -> None: # Set all2all_backend from env var if not specified, with deprecation warning diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 7e0878d96bbd..8140db33880e 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -3,14 +3,19 @@ """Utility functions for vLLM config dataclasses.""" import ast +import enum +import hashlib import inspect +import json +import pathlib import textwrap -from collections.abc import Iterable +from collections.abc import Iterable, Mapping, Sequence, Set from dataclasses import MISSING, Field, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar import regex as re +import torch from pydantic.fields import FieldInfo from typing_extensions import runtime_checkable @@ -176,3 +181,107 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: ) processed_overrides[field_name] = value return replace(config, **processed_overrides) + + +def normalize_value(x): + """Return a stable, JSON-serializable canonical form for hashing. + Order: primitives, special types (Enum, callable, torch.dtype, Path), then + generic containers (Mapping/Set/Sequence) with recursion. + """ + # Fast path + if x is None or isinstance(x, (bool, int, float, str)): + return x + + # Enums: tag with FQN to avoid primitive collisions. + # Ex: Enum(1) vs int(1) -> ("module.QualName", value). + if isinstance(x, enum.Enum): + enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}" + return (enum_type, normalize_value(x.value)) + + # Classes (types) are accepted and canonicalized by their fully-qualified + # name (module.qualname) for a stable identifier. + # Instances are only accepted if they expose uuid(); otherwise they are + # rejected to avoid under-hashing object state. + + # Callables: accept classes only; reject funcs/lambdas/methods. + # Used by LogitsProcessor types and ModelConfig.hf_overrides. 
+    if isinstance(x, type):
+        module = getattr(x, "__module__", "")
+        qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
+        return ".".join([p for p in (module, qual) if p]) or repr(x)
+
+    # Prefer stable uuid identifiers for objects that provide them, even if
+    # they are callable instances (e.g., InductorPass wrappers).
+    if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
+        return x.uuid()
+
+    if callable(x):
+        raise TypeError("normalize_value: function or callable instance unsupported")
+
+    # Torch dtypes stringify to a stable literal
+    # (torch.float64 -> "torch.float64"), so no type tag is needed.
+    if isinstance(x, torch.dtype):
+        return str(x)
+
+    # Bytes
+    if isinstance(x, (bytes, bytearray)):
+        return x.hex()
+
+    # Paths (canonicalize)
+    if isinstance(x, pathlib.Path):
+        try:
+            return str(x.expanduser().resolve())
+        except Exception:
+            return str(x)
+
+    # Dataclasses: represent as (FQN, sorted (field, value) tuple) for stability.
+    if is_dataclass(x):
+        type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
+        items = tuple(
+            (f.name, normalize_value(getattr(x, f.name)))
+            for f in sorted(fields(x), key=lambda f: f.name)
+        )
+        return (type_fqn, items)
+
+    # Containers (generic)
+    if isinstance(x, Mapping):
+        return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
+    if isinstance(x, Set):
+        return tuple(sorted(repr(normalize_value(v)) for v in x))
+    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
+        return tuple(normalize_value(v) for v in x)
+
+    # PretrainedConfig
+    if hasattr(x, "to_json_string") and callable(x.to_json_string):
+        return x.to_json_string()
+
+    # Unsupported type: e.g., modules, generators, open files, or objects
+    # without a stable JSON/UUID representation. Hard-error to avoid
+    # under-hashing.
+    raise TypeError(f"normalize_value: unsupported type '{type(x).__name__}'")
+
+
+def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
+    """Get the factors used for hashing a config class.
+    - Includes all dataclass fields not in `ignored_factors`.
+    - Errors on non-normalizable values.
+    """
+    factors: dict[str, object] = {}
+    for dc_field in fields(config):
+        factor = dc_field.name
+        if factor in ignored_factors:
+            continue
+        value = getattr(config, factor, None)
+        try:
+            factors[factor] = normalize_value(value)
+        except TypeError as e:
+            raise TypeError(
+                f"get_hash_factors: unsupported type for key '{factor}' "
+                f"({type(value).__name__})"
+            ) from e
+    return factors
+
+
+def hash_factors(items: dict[str, object]) -> str:
+    """Return a SHA-256 hex digest of the canonical items structure."""
+    return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
diff --git a/vllm/envs.py b/vllm/envs.py
index b99e2524318f..0b94eb466436 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-import hashlib
 import json
 import os
 import sys
@@ -1530,85 +1529,67 @@ def is_set(name: str):
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 
-def compute_hash() -> str:
+def compile_factors() -> dict[str, object]:
     """
-    WARNING: Whenever a new key is added to this environment
-    variables, ensure that it is included in the factors list if
-    it affects the computation graph. For example, different values
-    of VLLM_PP_LAYER_PARTITION will generate different computation
-    graphs, so it is included in the factors list. The env vars that
-    affect the choice of different kernels or attention backends should
-    also be included in the factors list.
+    Return the environment variables used to compute the compile cache key.
+    All known vLLM environment variables are included, except those that
+    cannot affect graph structure, codegen, or kernel selection
+    (see `ignored_factors` below).
     """
-    # The values of envs may affects the computation graph.
-    # TODO(DefTruth): hash all environment variables?
-    # for key in environment_variables:
-    #     factorize(key)
-    environment_variables_to_hash = [
-        "VLLM_PP_LAYER_PARTITION",
-        "VLLM_MLA_DISABLE",
-        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
-        "VLLM_USE_TRITON_AWQ",
-        "VLLM_DP_RANK",
-        "VLLM_DP_SIZE",
-        "VLLM_USE_STANDALONE_COMPILE",
-        "VLLM_FUSED_MOE_CHUNK_SIZE",
-        "VLLM_FLASHINFER_MOE_BACKEND",
-        "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
-        "VLLM_ATTENTION_BACKEND",
-        "VLLM_USE_FLASHINFER_SAMPLER",
-        "VLLM_DISABLED_KERNELS",
-        "VLLM_USE_DEEP_GEMM",
-        "VLLM_MOE_USE_DEEP_GEMM",
-        "VLLM_USE_DEEP_GEMM_E8M0",
-        "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
-        "VLLM_USE_FLASHINFER_MOE_FP16",
-        "VLLM_USE_FLASHINFER_MOE_FP8",
-        "VLLM_USE_FLASHINFER_MOE_FP4",
-        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
-        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
-        "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
-        "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
-        "VLLM_USE_CUDNN_PREFILL",
-        "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
-        "VLLM_USE_TRTLLM_ATTENTION",
-        "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
-        "VLLM_ROCM_USE_AITER",
-        "VLLM_ROCM_USE_AITER_PAGED_ATTN",
-        "VLLM_ROCM_USE_AITER_LINEAR",
-        "VLLM_ROCM_USE_AITER_MOE",
-        "VLLM_ROCM_USE_AITER_RMSNORM",
-        "VLLM_ROCM_USE_AITER_MLA",
-        "VLLM_ROCM_USE_AITER_MHA",
-        "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
-        "VLLM_ROCM_USE_AITER_TRITON_ROPE",
-        "VLLM_ROCM_USE_AITER_FP8BMM",
-        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
-        "VLLM_ROCM_USE_AITER_TRITON_GEMM",
-        "VLLM_ROCM_USE_SKINNY_GEMM",
-        "VLLM_ROCM_FP8_PADDING",
-        "VLLM_ROCM_MOE_PADDING",
-        "VLLM_ROCM_CUSTOM_PAGED_ATTN",
-        "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
-        "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
-        "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
-        "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
-        "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
-        "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
-        "VLLM_NVFP4_GEMM_BACKEND",
-        "VLLM_USE_FBGEMM",
-        "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
-        "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
-    ]
-    for key in environment_variables_to_hash:
-        # if this goes out of sync with environment_variables,
-        # it's not a user error, it's a bug
-        assert key in environment_variables, (
-            "Please update environment_variables_to_hash in envs.py"
-        )
-
-    factors = [environment_variables[key]() for key in environment_variables_to_hash]
+    ignored_factors: set[str] = {
+        "MAX_JOBS",
+        "VLLM_RPC_BASE_PATH",
+        "VLLM_USE_MODELSCOPE",
+        "VLLM_RINGBUFFER_WARNING_INTERVAL",
+        "VLLM_DEBUG_DUMP_PATH",
+        "VLLM_PORT",
+        "VLLM_CACHE_ROOT",
+        "LD_LIBRARY_PATH",
+        "VLLM_SERVER_DEV_MODE",
+        "VLLM_DP_MASTER_IP",
+        "VLLM_DP_MASTER_PORT",
+        "VLLM_RANDOMIZE_DP_DUMMY_INPUTS",
+        "VLLM_CI_USE_S3",
+        "VLLM_MODEL_REDIRECT_PATH",
+        "VLLM_HOST_IP",
+        "S3_ACCESS_KEY_ID",
+        "S3_SECRET_ACCESS_KEY",
+        "S3_ENDPOINT_URL",
+        "VLLM_USAGE_STATS_SERVER",
+        "VLLM_NO_USAGE_STATS",
+        "VLLM_DO_NOT_TRACK",
+        "VLLM_LOGGING_LEVEL",
+        "VLLM_LOGGING_PREFIX",
+        "VLLM_LOGGING_STREAM",
+        "VLLM_LOGGING_CONFIG_PATH",
+        "VLLM_LOG_STATS_INTERVAL",
+        "VLLM_DEBUG_LOG_API_SERVER_RESPONSE",
+        "VLLM_TUNED_CONFIG_FOLDER",
+        "VLLM_ENGINE_ITERATION_TIMEOUT_S",
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE", + "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", + "VLLM_SLEEP_WHEN_IDLE", + "VLLM_IMAGE_FETCH_TIMEOUT", + "VLLM_VIDEO_FETCH_TIMEOUT", + "VLLM_AUDIO_FETCH_TIMEOUT", + "VLLM_MEDIA_URL_ALLOW_REDIRECTS", + "VLLM_MEDIA_LOADING_THREAD_COUNT", + "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", + "VLLM_VIDEO_LOADER_BACKEND", + } + + from vllm.config.utils import normalize_value + + factors: dict[str, object] = {} + for factor, getter in environment_variables.items(): + if factor in ignored_factors: + continue + + raw = getter() + + factors[factor] = normalize_value(raw) ray_noset_env_vars = [ # Refer to @@ -1631,8 +1612,8 @@ def compute_hash() -> str: "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", "RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES", ] - factors.extend([os.getenv(var) for var in ray_noset_env_vars]) - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + for var in ray_noset_env_vars: + factors[var] = normalize_value(os.getenv(var)) - return hash_str + return factors diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py index 7202259ca21a..44b40ead973b 100644 --- a/vllm/logging_utils/__init__.py +++ b/vllm/logging_utils/__init__.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.logging_utils.formatter import NewLineFormatter +from vllm.logging_utils.lazy import lazy from vllm.logging_utils.log_time import logtime __all__ = [ "NewLineFormatter", + "lazy", "logtime", ] diff --git a/vllm/logging_utils/lazy.py b/vllm/logging_utils/lazy.py new file mode 100644 index 000000000000..3ade79896285 --- /dev/null +++ b/vllm/logging_utils/lazy.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable +from typing import Any + + +class lazy: + """Wrap a zero-argument callable evaluated only during log formatting.""" + + __slots__ = ("_factory",) + + def __init__(self, factory: Callable[[], Any]) -> None: + self._factory = factory + + def __str__(self) -> str: + return str(self._factory()) + + def __repr__(self) -> str: + return str(self)