Conversation

@zhiyuan1i (Contributor) commented on Sep 9, 2025

Purpose

This PR introduces a hybrid cache architecture that separates logical kernel block size from
physical page size, enabling more flexible memory management. Key changes include:

  • Added kernel_block_size field to CacheConfig for logical block sizing
  • Enhanced platform-specific configurations for CUDA and ROCm to support hybrid blocks
  • Implemented block table conversion logic between physical and logical representations
  • Added support for different physical/logical block size ratios in V1 worker components

In hybrid models, this decoupling enables high-performance operators to be developed independently, without being constrained by linear attention mechanisms such as Mamba, addressing the performance bottlenecks discussed in issues #24280 and #23161.
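
For concreteness, a minimal sketch of the conversion this decoupling requires is shown below. It is illustrative only: the helper name physical_to_logical is hypothetical rather than taken from this PR, and it assumes the physical page size is an integer multiple of the kernel block size, so that each physical block expands to a run of consecutive kernel blocks.

    import torch

    def physical_to_logical(
        physical_block_table: torch.Tensor,  # [num_reqs, max_physical_blocks]
        physical_page_size: int,
        kernel_block_size: int,
    ) -> torch.Tensor:
        """Expand physical block IDs into kernel-block (logical) IDs."""
        assert physical_page_size % kernel_block_size == 0
        ratio = physical_page_size // kernel_block_size
        if ratio == 1:
            return physical_block_table
        offsets = torch.arange(ratio, dtype=physical_block_table.dtype,
                               device=physical_block_table.device)
        # Physical block p covers kernel blocks p*ratio ... p*ratio + ratio - 1.
        logical = physical_block_table.unsqueeze(-1) * ratio + offsets
        return logical.flatten(start_dim=-2)

With 256-token physical pages and 64-token kernel blocks, for example, the ratio is 4 and physical block 3 maps to kernel blocks 12 through 15.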

Test Plan

Added comprehensive tests in tests/v1/worker/test_gpu_model_runner.py to verify:

  • Block table conversion between physical and logical representations
  • Proper handling of different block size ratios
  • Integration with existing GPU model runner functionality
  • Platform-specific configurations for CUDA and ROCm
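
As an illustration of the first bullet, a conversion check could look roughly like the sketch below. The test body is hypothetical and reuses the illustrative physical_to_logical helper from the description above; it is not the API exercised by the real tests.

    import torch

    def test_physical_to_logical_ratio_4():
        # Assumes the physical_to_logical sketch shown earlier is in scope.
        # One request holding physical blocks [3, 7]; 256-token pages,
        # 64-token kernel blocks, i.e. a ratio of 4.
        physical = torch.tensor([[3, 7]], dtype=torch.int32)
        logical = physical_to_logical(
            physical, physical_page_size=256, kernel_block_size=64)
        # Each physical block expands to 4 consecutive kernel blocks.
        expected = torch.tensor(
            [[12, 13, 14, 15, 28, 29, 30, 31]], dtype=torch.int32)
        assert torch.equal(logical, expected)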

Test Result

pytest tests/v1/worker/test_gpu_model_runner.py: 20 passed

tests/v1/worker/test_gpu_model_runner.py ....................                                                                                                                        [100%]

===================================================================================== warnings summary =====================================================================================
../../../../opt/conda/envs/vllm-upstream/lib/python3.12/site-packages/torch/cuda/__init__.py:63
  /opt/conda/envs/vllm-upstream/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
    import pynvml  # type: ignore[import]

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================== 20 passed, 3 warnings in 89.20s (0:01:29) =========================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a hybrid cache architecture to decouple logical and physical block sizes, which is a significant enhancement for memory management. The changes span configuration, platform-specific code, and the core block table management. The implementation in block_table.py appears solid. However, I've identified some critical issues in the tests intended to validate this new functionality. The tests are flawed and do not correctly verify the hybrid block logic, which could mask bugs. Additionally, there's a piece of logic in the GPUModelRunner that could be made more robust. My review focuses on fixing these test and implementation issues to ensure the new feature is reliable and well-tested.

@heheda12345 (Collaborator)

Also CC @tdoublep

@heheda12345 (Collaborator) left a comment

Discussed with @zhiyuan1i offline. Two major concerns:

  1. I'd prefer to calculate the kernel block size for each attention backend in gpu_model_runner.
  2. It would be great if BlockTable.block_table and BlockTable.physical_block_table could be merged into one tensor.

@zhiyuan1i (Contributor, Author)

@heheda12345 Thanks for the prompt feedback! I've addressed suggestion 2 and merged BlockTable.block_table and BlockTable.physical_block_table into a single tensor, as recommended. :)
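
For readers following the thread, here is a hedged sketch of what a single merged tensor could mean in practice; it is purely illustrative and the PR's actual BlockTable layout may differ. The idea is to store only the kernel-block-granularity table and expand physical page IDs into it as blocks are appended.

    import torch

    class KernelBlockTable:
        """Illustrative single-tensor block table at kernel-block granularity."""

        def __init__(self, max_num_reqs: int, max_kernel_blocks: int,
                     physical_page_size: int, kernel_block_size: int):
            assert physical_page_size % kernel_block_size == 0
            self.ratio = physical_page_size // kernel_block_size
            # The only stored table: kernel-block IDs, one row per request.
            self.block_table = torch.zeros(
                max_num_reqs, max_kernel_blocks, dtype=torch.int32)

        def append_physical_blocks(self, req_idx: int, start: int,
                                   physical_blocks: torch.Tensor) -> None:
            # `start` is the first free kernel-block slot for this request.
            # Expand each new physical page into `ratio` consecutive kernel
            # blocks and write them straight into the single tensor.
            offsets = torch.arange(self.ratio, dtype=physical_blocks.dtype)
            expanded = (physical_blocks.unsqueeze(-1) * self.ratio + offsets).flatten()
            self.block_table[req_idx, start:start + expanded.numel()] = expanded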

@zhiyuan1i zhiyuan1i force-pushed the hybrid-cache-groups branch 2 times, most recently from 6d1735e to 0b544bf on September 9, 2025 at 14:43
@tjtanaa (Contributor) commented on Sep 11, 2025

CC @gshtras @hongxiayang as this also affects ROCm

@mergify mergify bot added the performance, qwen, gpt-oss, speculative-decoding, and kv-connector labels on Oct 9, 2025
@zhiyuan1i zhiyuan1i force-pushed the hybrid-cache-groups branch from 6951014 to 10fabbb on October 9, 2025 at 03:48
@zhiyuan1i zhiyuan1i changed the title from "[Hybrid]: Decouple Logical Block Size from Physical Page Size" to "[Hybrid]: Decouple Kernel Block Size from KV Page Size" on Oct 9, 2025
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
@mergify mergify bot removed the needs-rebase label Oct 9, 2025
@heheda12345 (Collaborator) left a comment

LGTM! Thanks for this enhancement. Follow-ups:

  1. More clean-ups (@heheda12345)
  2. Verify the get_supported_kernel_block_size of each attention backend.

        else:
            self.reorder_batch_threshold = reorder_batch_threshold_i

    def _find_compatible_block_sizes(

Collaborator: (not a blocker) This function may be simplified.
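
For context, a deliberately simplified, hypothetical version of such a helper is sketched below; MultipleOf is redefined locally as a minimal stand-in, and the real function in this PR may behave differently. It keeps only the kernel block sizes that fit evenly into the KV manager block size.

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class MultipleOf:
        base: int  # minimal local stand-in for vLLM's MultipleOf marker

    def find_compatible_block_sizes(
        kv_manager_block_size: int,
        supported: list[Union[int, MultipleOf]],
    ) -> list[int]:
        """Kernel block sizes that fit evenly into the KV manager block size."""
        compatible: set[int] = set()
        for s in supported:
            if isinstance(s, MultipleOf):
                # The manager block size itself works whenever it is a
                # multiple of the required base.
                if kv_manager_block_size % s.base == 0:
                    compatible.add(kv_manager_block_size)
            elif kv_manager_block_size % s == 0:
                compatible.add(s)
        return sorted(compatible, reverse=True)

For example, find_compatible_block_sizes(128, [MultipleOf(16)]) returns [128], and find_compatible_block_sizes(128, [16, 32, 64]) returns [64, 32, 16].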

        num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
        if isinstance(kv_cache_spec, AttentionSpec):
            has_attn = True
            kv_manager_block_size = kv_cache_spec.block_size

Collaborator: (not a blocker) Should we use the common block size of all attention groups in the same KV cache group here?
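
To make the question concrete, one hypothetical way to choose such a common size (not code from this PR) is to intersect the candidate kernel block sizes of all attention groups in the KV cache group and take the largest:

    def common_kernel_block_size(per_group_candidates: list[list[int]]) -> int:
        """Largest kernel block size usable by every attention group in the group."""
        common = set(per_group_candidates[0])
        for candidates in per_group_candidates[1:]:
            common &= set(candidates)
        if not common:
            raise ValueError("no kernel block size is compatible with every group")
        return max(common)

For example, common_kernel_block_size([[64, 32, 16], [32, 16]]) returns 32.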

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Oct 9, 2025
@heheda12345 heheda12345 enabled auto-merge (squash) October 9, 2025 06:24
@vllm-bot vllm-bot merged commit d24cf32 into vllm-project:main Oct 9, 2025
58 of 60 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Oct 10, 2025
…to loader

* 'loader' of https://github.com/dsxsteven/vllm_splitPR: (778 commits)
  [torchao] Add support for ModuleFqnToConfig using regex (vllm-project#26001)
  Add: Support for multiple hidden layers in Eagle3 (vllm-project#26164)
  Enable `RMSNorm` substitution for Transformers backend (vllm-project#26353)
  [Model] Gemma3: Fix GGUF loading and quantization (vllm-project#26189)
  Bump Flashinfer to v0.4.0 (vllm-project#26326)
  Update Dockerfile and install runai-model-streamer[gcs] package (vllm-project#26464)
  [Core] Relax the LoRA  max rank (vllm-project#26461)
  [CI/Build] Fix model nightly tests (vllm-project#26466)
  [Hybrid]: Decouple Kernel Block Size from KV Page Size (vllm-project#24486)
  [Core][KVConnector] Propagate all tokens on resumed preemptions (vllm-project#24926)
  [MM][Doc] Add documentation for configurable mm profiling (vllm-project#26200)
  [Hardware][AMD] Enable FlexAttention backend on ROCm (vllm-project#26439)
  [Bugfix] Incorrect another MM data format in vllm bench throughput (vllm-project#26462)
  [Bugfix] Catch and log invalid token ids in detokenizer #2 (vllm-project#26445)
  [Minor] Change warning->warning_once in preprocess (vllm-project#26455)
  [Bugfix] Set the minimum python version for gpt-oss (vllm-project#26392)
  [Misc] Redact ray runtime env before logging (vllm-project#26302)
  Separate MLAAttention class from Attention (vllm-project#25103)
  [Attention] Register FLASHMLA_SPARSE (vllm-project#26441)
  [Kernels] Modular kernel refactor (vllm-project#24812)
  ...
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

    @staticmethod
    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
        return [MultipleOf(16)]

Collaborator: Technically FA3 would support MultipleOf(1) while FA2 would support MultipleOf(16); I don't think it's worth handling this, though.
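
As a purely hypothetical illustration of that distinction (not the backend's real code), the declaration could be made version-dependent; MultipleOf below is again a minimal local stand-in for vLLM's marker class.

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class MultipleOf:
        base: int  # minimal local stand-in for vLLM's MultipleOf marker

    class FlashAttentionBackendSketch:
        """Hypothetical backend reporting a version-dependent block-size constraint."""

        fa_version: int = 2

        @classmethod
        def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
            # FA3 has no alignment requirement; FA2 needs multiples of 16.
            return [MultipleOf(1)] if cls.fa_version == 3 else [MultipleOf(16)]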

Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
…24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
…24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

Labels

ci/build, deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), gpt-oss (Related to GPT-OSS models), kv-connector, performance (Performance-related issues), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, tpu (Related to Google TPUs), v1

Projects

Status: Done
