163 changes: 163 additions & 0 deletions tests/config/test_config_utils.py
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import pytest

from vllm.config.utils import get_hash_factors, hash_factors, normalize_value

# Helpers

def endswith_fqname(obj, suffix: str) -> bool:
    # normalize_value(type) returns fully-qualified name
    # Compare suffix to avoid brittle import paths.
    out = normalize_value(obj)
    return isinstance(out, str) and out.endswith(suffix)


def expected_path(p_str: str = ".") -> str:
    import pathlib

    p = pathlib.Path(p_str)
    return p.expanduser().resolve().as_posix()


# Minimal dataclass to test get_hash_factors.
# Avoid importing heavy vLLM configs.
@dataclass
class SimpleConfig:
    a: object
    b: Optional[object] = None


class DummyLogprobsMode(Enum):
    RAW_LOGITS = "raw_logits"


def test_hash_factors_deterministic():
    """Test that hash_factors produces consistent SHA-256 hashes"""
    factors = {"a": 1, "b": "test"}
    hash1 = hash_factors(factors)
    hash2 = hash_factors(factors)

    assert hash1 == hash2
    assert len(hash1) == 64
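    # A SHA-256 digest is 32 bytes, i.e. 64 lowercase hex characters.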
    assert all(c in "0123456789abcdef" for c in hash1)


@pytest.mark.parametrize(
    "inp, expected",
    [
        (None, None),
        (True, True),
        (1, 1),
        (1.0, 1.0),
        ("x", "x"),
        (b"ab", "6162"),
        (bytearray(b"ab"), "6162"),
        ([1, 2], (1, 2)),
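        # dicts normalize to a tuple of (key, value) pairs sorted by key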
        ({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
    ],
)
def test_normalize_value_matrix(inp, expected):
    """Parametric input→expected normalization table."""
    assert normalize_value(inp) == expected


def test_normalize_value_enum():
    # Enums normalize to (module.QualName, value).
    # DummyLogprobsMode uses a string payload.
    out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(out, tuple)
    assert out[0].endswith("DummyLogprobsMode")
    # Expect string payload 'raw_logits'.
    assert out[1] == "raw_logits"


def test_normalize_value_set_order_insensitive():
    # Sets are unordered; normalize_value sorts elements for determinism.
    assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})


def test_normalize_value_path_normalization():
    from pathlib import Path  # local import to avoid global dependency

    # Paths expand/resolve to absolute strings.
    # Stabilizes hashing across working dirs.
    assert normalize_value(Path(".")) == expected_path(".")


def test_normalize_value_uuid_and_to_json():
    # Objects may normalize via uuid() or to_json_string().
    class HasUUID:
        def uuid(self):
            return "test-uuid"

    class ToJson:
        def to_json_string(self):
            return '{"x":1}'

    assert normalize_value(HasUUID()) == "test-uuid"
    assert normalize_value(ToJson()) == '{"x":1}'


@pytest.mark.parametrize(
    "bad",
    [
        (lambda x: x),
        (type("CallableInstance", (), {"__call__": lambda self: 0}))(),
        (lambda: (lambda: 0))(),  # nested function instance
    ],
)
def test_error_cases(bad):
    """Inputs expected to raise TypeError."""
    # Reject functions/lambdas/callable instances
    # to avoid under-hashing.
    with pytest.raises(TypeError):
        normalize_value(bad)


def test_enum_vs_int_disambiguation():
    # int stays primitive
    nf_int = normalize_value(1)
    assert nf_int == 1

    # enum becomes ("module.QualName", value)
    nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
    assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
    enum_type, enum_val = nf_enum
    assert enum_type.endswith(".DummyLogprobsMode")
    assert enum_val == "raw_logits"

    # Build factor dicts from configs with int vs enum
    f_int = get_hash_factors(SimpleConfig(1), set())
    f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
    # The int case remains a primitive value
    assert f_int["a"] == 1
    # The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
    assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
    # Factor dicts must differ so we don't collide primitives with Enums.
    assert f_int != f_enum
    # Hash digests must differ correspondingly
    assert hash_factors(f_int) != hash_factors(f_enum)

    # Hash functions produce stable hex strings
    h_int = hash_factors(f_int)
    h_enum = hash_factors(f_enum)
    assert isinstance(h_int, str) and len(h_int) == 64
    assert isinstance(h_enum, str) and len(h_enum) == 64


def test_classes_are_types():
    """Types normalize to fully-qualified names (FQNs)."""
    # Only classes allowed; functions/lambdas are rejected.
    # Canonical form is the fully-qualified name.
    assert isinstance(normalize_value(str), str)

    class LocalDummy:
        pass

    assert endswith_fqname(LocalDummy, ".LocalDummy")
104 changes: 84 additions & 20 deletions vllm/compilation/backends.py
@@ -4,11 +4,14 @@
import ast
import dataclasses
import hashlib
import json
import logging
import os
import pprint
import time
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
@@ -22,8 +25,10 @@
    resolve_defined_ops,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.logging_utils import lazy
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer

@@ -574,32 +579,46 @@ def __call__(
        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
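        # sorted() keeps the traced-file order deterministic so the code
        # hash computed below is stable across runs.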
        forward_code_files = list(sorted(self.compilation_config.traced_files))

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except Exception:
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            factors = compilation_config_hash_factors(vllm_config)
            # 2. factors come from the code files that are traced by Dynamo (
            # it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )
            self.compilation_config.cache_dir = cache_dir

@@ -627,6 +646,51 @@ def __call__(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key

        logger.info(
            "torch.compile cache factors: env=%s cfg=%s comp=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            local_cache_dir,
        )
        logger.debug("code hash=%s", code_hash)

        # Persist and log only hash-relevant factors together.
        try:
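            # Only build the pretty-printed env factors when DEBUG logging
            # is enabled.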
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "Compile env factors (raw):\n%s\nVllm config hash: %s",
                    lazy(partial(pprint.pformat, env_factors, width=120)),
                    config_hash,
                )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
31 changes: 23 additions & 8 deletions vllm/config/cache.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal

@@ -152,13 +151,29 @@ def compute_hash(self) -> str:
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        factors.append(self.cache_dtype)
        factors.append(self.mamba_cache_dtype)
        factors.append(self.mamba_ssm_cache_dtype)
        # `cpu_offload_gb` does not use `torch.compile` yet.
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str
        ignored_factors = {
            # Runtime/derived knobs that don't affect compiled graph shape
            "gpu_memory_utilization",
            "swap_space",
            "is_attention_free",
            "num_gpu_blocks_override",
            "enable_prefix_caching",
            "prefix_caching_hash_algo",
            "cpu_offload_gb",
            "calculate_kv_scales",
            "cpu_kvcache_space_bytes",
            "mamba_page_size_padded",
            # Post-init/derived counters
            "num_gpu_blocks",
            "num_cpu_blocks",
            # WIP feature toggle not impacting compiled graph shape
            "kv_sharing_fast_prefill",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus