
Conversation

@heheda12345 (Collaborator) commented Oct 6, 2025

Purpose

In DCP (decode context parallelism), the scheduler's block_size is multiplied by the DCP size. This PR applies the same scaling to the block size used by the prefix-caching hash function.
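
As a minimal sketch of the intended change (the actual call site is shown in the review comments below), the block size passed to the request block hasher becomes:

# Under DCP, the KV cache for one logical block is spread across
# decode_context_parallel_size ranks, so the scheduler treats a block as
# covering block_size * dcp_size tokens; hashing should chunk requests at
# that same granularity.
hash_block_size = (
    vllm_config.cache_config.block_size
    * vllm_config.parallel_config.decode_context_parallel_size
)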

Test Plan

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.inputs.data import TokensPrompt

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
# Two prompts that share a 150-token prefix and then diverge, so a correct
# prefix-cache hit between them can cover at most 150 tokens.
prompt_token_ids_1 = [TokensPrompt(prompt_token_ids=[23333] * 150 + [12222] * 400)]
prompt_token_ids_2 = [TokensPrompt(prompt_token_ids=[23333] * 150 + [12223] * 400)]


def main():
    # Create an LLM.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",
        tensor_parallel_size=2,
        decode_context_parallel_size=2,
        load_format="dummy",
        enforce_eager=True,
    )
    # Generate texts from the prompts. The outputs are lists of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # The second request should hit the prefix cache populated by the first.
    output1 = llm.generate(prompt_token_ids_1, sampling_params)
    output2 = llm.generate(prompt_token_ids_2, sampling_params)


if __name__ == "__main__":
    main()

Test Result

The cache hit length should be < 150, since the two prompts only share their first 150 tokens.

Before this PR, the cache hit lengths of the two requests are (256 exceeds the shared prefix, i.e. an incorrect hit):

(EngineCore_DP0 pid=1746435) num_new_local_computed_tokens 0
(EngineCore_DP0 pid=1746435) num_new_local_computed_tokens 256

After this PR, the cache hit lengths are:

(EngineCore_DP0 pid=1737704) num_new_local_computed_tokens 0
(EngineCore_DP0 pid=1737704) num_new_local_computed_tokens 128
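
For intuition, a back-of-envelope check of the post-fix number. It assumes the per-GPU block size resolved to 64 for this MLA model (an assumption; the actual value depends on the cache config and attention backend) and the DCP size of 2 from the test script:

block_size = 64      # assumed per-GPU block size for this MLA model (hypothetical)
dcp_size = 2         # decode_context_parallel_size from the test script
shared_prefix = 150  # the two prompts share their first 150 tokens

hash_block = block_size * dcp_size         # 128 tokens per hashed block
full_blocks = shared_prefix // hash_block  # 1 complete shared block
print(full_blocks * hash_block)            # 128, matching the post-fix hit length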

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes a bug in prefix caching with Decode Context Parallelism (DCP) by correctly calculating the block size for hashing. The change ensures that the block_size is multiplied by decode_context_parallel_size, which is consistent with how the scheduler handles block sizes in DCP. My review includes a suggestion to refactor the code to avoid logic duplication, which was the root cause of this bug, thereby improving maintainability and preventing similar issues in the future.

Comment on lines 180 to 183
            hash_block_size = (
                vllm_config.cache_config.block_size
                * vllm_config.parallel_config.decode_context_parallel_size
            )

Severity: high

This calculation for hash_block_size duplicates logic from the Scheduler's __init__ method, where its block_size attribute is also adjusted for decode_context_parallel_size. This duplication is what led to the original bug this PR is fixing.

To avoid this, you can reuse the block_size from the scheduler instance. This makes the code more robust by using a single source of truth.

While block_size is not on the SchedulerInterface, using it with a type: ignore is a pragmatic way to remove the duplication within this file. A more complete solution would involve exposing the logical block size through the interface or a shared utility, which could be addressed in a follow-up.

            hash_block_size = self.scheduler.block_size  # type: ignore
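
For context, a hypothetical, self-contained illustration of the "single source of truth" idea this suggestion describes: compute the DCP-scaled block size once (as the Scheduler's __init__ reportedly does) and reuse it for hashing instead of recomputing it at the call site. The class below is a stub, not the real Scheduler.

class SchedulerStub:
    def __init__(self, block_size: int, dcp_size: int) -> None:
        # Mirrors the comment's description of Scheduler.__init__: the
        # block_size attribute is already adjusted for DCP.
        self.block_size = block_size * dcp_size

scheduler = SchedulerStub(block_size=16, dcp_size=2)
hash_block_size = scheduler.block_size  # reuse the scheduler's value
print(hash_block_size)                  # 32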


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 177 to 190
            self.vllm_config.cache_config.enable_prefix_caching
            or self.scheduler.get_kv_connector() is not None
        ):
            block_size = vllm_config.cache_config.block_size
            hash_block_size = (
                vllm_config.cache_config.block_size
                * vllm_config.parallel_config.decode_context_parallel_size
            )
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
-               block_size, caching_hash_fn
+               hash_block_size, caching_hash_fn


P1: Avoid scaling block hashes when only KV connectors use them

The new hash setup multiplies cache_config.block_size by decode_context_parallel_size before computing request.block_hashes. That makes each hash correspond to block_size × dcp_world_size tokens (engine/core.py lines 175‑190). KV connectors, however, still interpret request.block_hashes in units of the original GPU block size; for example OffloadingConnectorScheduler.get_num_new_matched_tokens asserts len(request.block_hashes) // self.block_size_factor == num_blocks where self.block_size_factor is derived from cache_config.block_size (offloading_connector.py lines 179‑181 and spec.py lines 33‑38). With DCP>1 and a KV connector enabled but prefix caching disabled, len(request.block_hashes) shrinks by the DCP factor and the assertion will fail or the connector will mis-index blocks. Either keep hashes at the GPU block granularity when only a connector is present or update the connector code to handle the larger hash stride.
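
To make the failure mode concrete, here is an illustrative calculation; the block sizes are made-up values, and only the relationships mirror the connector code cited above:

gpu_block_size = 16       # stand-in for cache_config.block_size (illustrative)
offload_block_size = 64   # stand-in for the connector's offload block size (illustrative)
block_size_factor = offload_block_size // gpu_block_size        # connector expects 4
dcp_world_size = 2
num_tokens = 256

num_offload_blocks = num_tokens // offload_block_size                        # 4
hashes_at_gpu_granularity = num_tokens // gpu_block_size                     # 16 (pre-change)
hashes_at_dcp_granularity = num_tokens // (gpu_block_size * dcp_world_size)  # 8 (post-change)

print(hashes_at_gpu_granularity // block_size_factor == num_offload_blocks)  # True
print(hashes_at_dcp_granularity // block_size_factor == num_offload_blocks)  # False: the connector's check breaks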

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@youkaichao (Member) left a comment


@youkaichao youkaichao enabled auto-merge (squash) October 9, 2025 02:31
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 9, 2025
@youzhedian (Contributor) commented:

Thank you for the fix. LGTM.

@mergify mergify bot added the kv-connector label Oct 9, 2025
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@vllm-bot vllm-bot merged commit 606b00e into vllm-project:main Oct 10, 2025
43 of 46 checks passed
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…ect#26296)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
…ect#26296)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Oct 24, 2025
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vLLM main; this PR is aligned with
vllm-project/vllm@17c540a

1. Refactor DeepSeek to the latest code architecture as of
vllm-project/vllm@17c540a
 
2. A bunch of fixes due to vLLM changes:
- Fix `AscendScheduler` `__post_init__`, caused by
vllm-project/vllm#25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by vllm-project/vllm#26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
vllm-project/vllm#23485
- Fix `MLAAttention` import, caused by
vllm-project/vllm#25103
- Fix `SharedFusedMoE` import, caused by
vllm-project/vllm#26145
- Fix `LazyLoader` import, caused by
vllm-project/vllm#27022
- Fix `vllm.utils.swap_dict_values` import, caused by
vllm-project/vllm#26990
- Fix `Backend` enum import, caused by
vllm-project/vllm#25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by vllm-project/vllm#26355
- Fix fused_moe ops, caused by
vllm-project/vllm#24097
- Fix bert model because of `inputs_embeds`, caused by
vllm-project/vllm#25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
vllm-project/vllm#24172
- Fix `splitting_ops` changes introduced by
vllm-project/vllm#25845
- Fix multi-modality changes introduced by
vllm-project/vllm#16229
- Fix lora bias dropping issue introduced by
vllm-project/vllm#25807
- Fix structured output breakage introduced by
vllm-project/vllm#26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…ect#26296)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…ect#26296)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…ect#26296)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>

Labels

kv-connector · ready (ONLY add when PR is ready to merge/full CI is needed) · v1
