Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ python-json-logger # Used by logging as per examples/others/logging_configuratio
scipy # Required for phi-4-multimodal-instruct
ninja # Required for xgrammar, rocm, tpu, xpu
pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure who can double check the changes to requirements. Maybe @DarkLight1337 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with this addition. cc @mgoin to be sure

1 change: 1 addition & 0 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ruff
# Required for argparse hook only
-f https://download.pytorch.org/whl/cpu
cachetools
cbor2
cloudpickle
fastapi
msgspec
Expand Down
30 changes: 20 additions & 10 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
hash_block_tokens, hash_request_tokens, init_none_hash,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
Expand Down Expand Up @@ -78,24 +79,27 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window)


def test_none_hash(monkeypatch):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils

# case 1: PYTHONHASHSEED is not set, use random
with monkeypatch.context() as m:
m.delenv('PYTHONHASHSEED', raising=False)
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert reloaded_kv_cache_utils.NONE_HASH != 0

# case 2: PYTHONHASHSEED is set, use the seed
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m:
m.setenv('PYTHONHASHSEED', 'python hash seed')
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH


def test_kv_cache_block():
Expand Down Expand Up @@ -287,9 +291,10 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1


@pytest.mark.parametrize("hash_fn", [sha256, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
parent_block_hash = 123
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
Expand All @@ -303,9 +308,10 @@ def test_hash_block_tokens(hash_fn):
assert block_hash.extra_keys == extra_keys


@pytest.mark.parametrize("hash_fn", [sha256, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
Expand All @@ -332,8 +338,10 @@ def test_hash_request_tokens(hash_fn):
assert block_hashes[1].extra_keys == ("hash2", )


@pytest.mark.parametrize("hash_fn", [sha256, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)

request1 = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
Expand All @@ -359,8 +367,10 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1]


@pytest.mark.parametrize("hash_fn", [sha256, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn)

request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
Expand Down Expand Up @@ -916,4 +926,4 @@ def test_get_kv_cache_config():
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
])
14 changes: 9 additions & 5 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, hash_block_tokens)
KVCacheBlock, hash_block_tokens,
init_none_hash)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)

Expand Down Expand Up @@ -91,7 +92,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)


@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
make_kv_cache_config(16, 11),
Expand All @@ -101,7 +102,8 @@ def test_prefill(hash_algo):
)

# choose the hash function according to the parameter
hash_fn = sha256 if hash_algo == "sha256" else hash
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
sha256 if hash_algo == "sha256" else hash)

# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
Expand Down Expand Up @@ -696,12 +698,14 @@ def test_basic_prefix_caching_disabled():
assert not blocks


@pytest.mark.parametrize("hash_fn", [sha256, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_cache_blocks(hash_fn):
"""
This is a unit test that tests the correctness of the _cache_full_blocks
function of KVCacheManager.
"""
init_none_hash(hash_fn)

block_size = 4
block_pool = BlockPool(
num_gpu_blocks=5,
Expand Down
9 changes: 7 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ def get_and_verify_max_len(self, max_model_len: int):

BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]


@config
Expand Down Expand Up @@ -1548,7 +1548,12 @@ class CacheConfig:
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads."""
- "sha256" is collision resistant but with certain overheads.
This option uses Pickle for object serialization before hashing.\n
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
Expand Down
24 changes: 24 additions & 0 deletions vllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from uuid import uuid4

import cachetools
import cbor2
import cloudpickle
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -3177,6 +3178,29 @@ def sha256(input) -> int:
byteorder="big")


def sha256_cbor_64bit(input) -> int:
    """
    Return a 64-bit hash of `input` built from canonical CBOR and SHA-256.

    The object is first encoded with `cbor2.dumps(..., canonical=True)`, the
    encoded bytes are hashed with SHA-256, and only the lowest 64 bits of the
    digest are kept. Because neither step relies on Python's pickle format,
    this is useful for non-Python-dependent serialization and hashing.

    Args:
        input: Object to be serialized and hashed. Supported types include
            basic Python types and complex structures like lists, tuples, and
            dictionaries.
            Custom classes must implement CBOR serialization methods.

    Returns:
        An integer in the range [0, 2^64-1] representing the lower 64 bits
        of the SHA-256 hash of the CBOR serialized input.
    """
    serialized = cbor2.dumps(input, canonical=True)
    digest = hashlib.sha256(serialized).digest()

    # Read as a big-endian integer, the lowest 64 bits of the 32-byte digest
    # are exactly its final 8 bytes.
    return int.from_bytes(digest[-8:], byteorder="big")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given this, would sha256_cbor_64bit be a better name?

Copy link
Contributor Author

@vMaroon vMaroon Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intended to be temporary until KVEvents support for >64bit hashes is fixed in a dedicated PR, so right now the trimming is only mentioned in formal argument description, if you think that it isn't clear enough then yes.

Temporary in the sense that longer hashes can be supported after. I would still want 64bit trimmed hashes until solid arguments are made against it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they are independent, can you use 256 bit format in this pr and fix kv event later? I think this pr can be tested without kv event fix.

Copy link
Contributor Author

@vMaroon vMaroon Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'll change the name to your suggestion and introduce a 256bit option with the fix.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the fix, is there any reason to use the 64 bit version? If not, I prefer only keeping the 256 bit version and wait for KV event fix to unblock it.

Copy link
Contributor Author

@vMaroon vMaroon Jul 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a confusion here, apologies.

I propose keeping it 64-bit until further experimentation. The intention is to add a 256-bit CBOR->SHA256 option (as well as xxHash) if it makes sense, but that is for follow-ups. I thought of the KVEvents bug as an easy additional justifier, but that seems to have backfired.

Personally I prefer this option, as trimming down to 64 bits cuts the bandwidth needed to transport hashes to a quarter while remaining extremely well suited to pretty much all applications - as mentioned in a previous comment.

Copy link
Contributor Author

@vMaroon vMaroon Jul 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a 64-bit truncated SHA-256 is still extremely useful for performant and low-collision hashing - it still provides strong distribution and negligible collision risk at scale:
Think of a 1M-token cache (practically GPU+CPU) -> ~62K blocks (16-token blocks) -> ~1 in 10 billion collision probability using the birthday paradox approximation: $p \approx n^2 / (2 \cdot 2^{64}) = 62000^2 / 2^{65} \approx 10^{-10}$.

Copy link
Contributor Author

@vMaroon vMaroon Jul 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@heheda12345 I hope that the above arguments are good for proceeding forward - what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Can you add a warning for "PYTHONHASHSEED not set" + "cbor"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.



def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.

Expand Down
9 changes: 6 additions & 3 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens)
hash_request_tokens, init_none_hash)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
Expand Down Expand Up @@ -79,7 +79,10 @@ def __init__(
self.max_model_len = max_model_len

self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
Expand Down
28 changes: 20 additions & 8 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
Expand Down Expand Up @@ -46,18 +46,30 @@ def get_hash_value(self) -> int:
return self.block_hash.hash_value


# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done one per process.
# The hash seed for the first block of any prefix block sequence.
#
# We use a random value to avoid hash collisions, or the PYTHONHASHSEED
# environment variable if it is set, so that processes can share the seed
# if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
#
# The function `init_none_hash` initializes this variable globally.
NONE_HASH: int


def init_none_hash(hash_fn: Callable):
global NONE_HASH

hash_seed = os.getenv("PYTHONHASHSEED")
if hash_seed is None and hash_fn is sha256_cbor_64bit:
logger.warning(
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor_64bit as the hash function. "
"Consider setting PYTHONHASHSEED to a fixed value for "
"reproducibility.")

NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this urandom reproducible? If not, I think we need a check for PYTHONHASHSEED is set when using sha256_cbor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should differentiate between the algorithm as an additional hashing function and the need for interoperability. Right now sha256_cbor is the only reproducible option but in the future others such as xxhash_cbor would follow, therefore it is not a 1:1 relationship.

Thus I think it is sensible to keep the support for a random root hash if the seed is not explicitly set regardless of the hashing algorithm - as it was before this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@heheda12345 I'll piggyback on this comment for a more manageable discussion - do you see benchmarking as a merge blocker? Or would it be fine to follow-up on the RFC / next PR on xxhash_cbor?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Thus I think it is sensible to keep the support for a random root hash if the seed is not explicitly set regardless of the hashing algorithm - as it was before this PR.
What about adding a warning? I think we need to remind the user to set the seed when they need interoperability.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need an end-to-end test. Just something like #15297 is enough.

Copy link
Contributor Author

@vMaroon vMaroon Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some profiling, like @dr75 said, I think the difference is (still) negligible.

=== System Information ===
Platform: macOS-15.5-arm64-arm-64bit-Mach-O
Processor: arm
Python version: 3.13.5
CPU count: 8
RAM: 32.0 GB
=========================

=== Hash Function Profiling Summary ===
AI workload equivalent per run: 50,000 tokens processed
Profiling config: 1000 runs, 3125 blocks/run, block_size=16
---------------------------------------
hash: mean=0.0012s, std=0.0020s
    Mean time per token: 0.00000002s
sha256: mean=0.0054s, std=0.0003s
    Mean time per token: 0.00000011s
sha256_cbor_64bit: mean=0.0171s, std=0.0044s
    Mean time per token: 0.00000034s
---------------------------------------
Comparison (relative slowdown, higher is slower):
    hash: 1.00x (baseline) mean diff: (0.0000s per 50,000 tokens, 0.00000000s per token)
    sha256: 4.62x  mean diff: (+0.0042s per 50,000 tokens, +0.00000008s per token)
    sha256_cbor_64bit: 14.73x  mean diff: (+0.0159s per 50,000 tokens, +0.00000032s per token)
=======================================

code: https://pastebin.com/7thahB9Y

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a warning for "PYTHONHASHSEED not set" + "cbor"?

if hash_seed is None else hash_fn(hash_seed))


class PrefixCachingMetrics:
Expand Down