-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
[Prefix Cache] Add reproducible prefix-cache block hashing using SHA-256 + CBOR (64bit) #20511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
154b049
fff4be5
b33577d
612862a
febb658
e493a8c
2c0d381
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,6 +52,7 @@ | |
| from uuid import uuid4 | ||
|
|
||
| import cachetools | ||
| import cbor2 | ||
| import cloudpickle | ||
| import numpy as np | ||
| import numpy.typing as npt | ||
|
|
@@ -3177,6 +3178,29 @@ def sha256(input) -> int: | |
| byteorder="big") | ||
|
|
||
|
|
||
| def sha256_cbor_64bit(input) -> int: | ||
| """ | ||
| Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. | ||
|
|
||
| This option is useful for non-Python-dependent serialization and hashing. | ||
|
|
||
| Args: | ||
| input: Object to be serialized and hashed. Supported types include | ||
| basic Python types and complex structures like lists, tuples, and | ||
| dictionaries. | ||
| Custom classes must implement CBOR serialization methods. | ||
|
|
||
| Returns: | ||
| An integer in the range [0, 2^64-1] representing the lower 64 bits | ||
| of the SHA-256 hash of the CBOR serialized input. | ||
| """ | ||
| input_bytes = cbor2.dumps(input, canonical=True) | ||
| full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), | ||
| byteorder="big") | ||
|
|
||
| return full_hash & ((1 << 64) - 1) | ||
|
||
|
|
||
|
|
||
| def is_torch_equal_or_newer(target: str) -> bool: | ||
| """Check if the installed torch version is >= the target version. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.utils import GiB_bytes, cdiv, sha256 | ||
| from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit | ||
| from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, | ||
| KVCacheGroupSpec, KVCacheSpec, | ||
| KVCacheTensor, SlidingWindowSpec) | ||
|
|
@@ -46,18 +46,30 @@ def get_hash_value(self) -> int: | |
| return self.block_hash.hash_value | ||
|
|
||
|
|
||
| # The hash seed for the first block of the prefix block sequence. | ||
| # | ||
| # Even if the hash function is the builtin hash(), we use sha256 to generate | ||
| # the initial hash to simplify the code. This is not performance critical | ||
| # as it is done one per process. | ||
| # The hash seed for the first block of any prefix block sequence. | ||
| # | ||
| # We use a random value to avoid hash collisions or PYTHONHASHSEED environment | ||
| # variable if set such that processes can share the seed if needed. | ||
| # This aligns with the behavior of Python's hash() function, which also uses | ||
| # a random seed if PYTHONHASHSEED is not set. | ||
| NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( | ||
| "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED")) | ||
| # | ||
| # The function `init_none_hash` initializes this variable globally. | ||
| NONE_HASH: int | ||
|
|
||
|
|
||
| def init_none_hash(hash_fn: Callable): | ||
| global NONE_HASH | ||
|
|
||
| hash_seed = os.getenv("PYTHONHASHSEED") | ||
| if hash_seed is None and hash_fn is sha256_cbor_64bit: | ||
| logger.warning( | ||
| "PYTHONHASHSEED is not set. This will lead to non-reproducible " | ||
| "block-hashes when using sha256_cbor_64bit as the hash function." | ||
| "Consider setting PYTHONHASHSEED to a fixed value for " | ||
| "reproducibility.") | ||
|
|
||
| NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") | ||
|
||
| if hash_seed is None else hash_fn(hash_seed)) | ||
|
|
||
|
|
||
| class PrefixCachingMetrics: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure who can double check the changes to requirements. Maybe @DarkLight1337 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with this addition. cc @mgoin to be sure