
Conversation

orozery
Contributor

@orozery orozery commented Sep 4, 2025

This is the final PR enabling CPU offloading in v1.

Concludes RFC #19854.
Depends on #20075, #21448, #22595.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant new feature: CPU offloading for v1. The implementation is extensive, adding a new offloading framework with managers, specs, handlers, and a dedicated KV connector. The code is well-structured, with a clear separation of concerns between scheduler and worker logic. My review focuses on critical aspects of reliability and resource management, and I've identified a few high-impact issues that should be addressed to ensure the robustness of this new feature.

finished_recving = set()
for job_id, success in self.worker.get_finished():
    # we currently do not support job failures
    assert success
Contributor

critical

Using assert success is risky as it will crash the worker process if an offloading transfer fails for any reason (e.g., I/O error, out of space). This can bring down the entire system. Failures should be handled more gracefully, for instance by logging a critical error and cleaning up the state for the failed job, without crashing the worker. While the comment indicates failures are not supported, using an assert is not a robust way to enforce this in production code.
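For illustration, a rough sketch of more graceful handling (the _abort_job cleanup helper is hypothetical, not something this PR provides):

finished_recving = set()
for job_id, success in self.worker.get_finished():
    if not success:
        # hypothetical cleanup: log the failure and drop state for the failed job
        logger.error("Offloading transfer for job %s failed", job_id)
        self._abort_job(job_id)
        continue
    finished_recving.add(job_id)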

Comment on lines 284 to 282
if store_output is None:
    logger.warning("Cannot store %s blocks", num_new_blocks)
    break
Contributor

high

Using break here will prematurely exit the loop that iterates over scheduled requests. If prepare_store fails for one request (by returning None), subsequent requests in the same scheduling step will not be considered for offloading. This could lead to offloading starvation for other requests. You should use continue to proceed to the next request in the loop.

Suggested change
if store_output is None:
    logger.warning("Cannot store %s blocks", num_new_blocks)
    break
if store_output is None:
    logger.warning("Cannot store %s blocks", num_new_blocks)
    continue

@ApostaC
Collaborator

ApostaC commented Sep 12, 2025

Is this PR the combination of #19848, #20075, #21448, and #22595?
I've reviewed the above 4 PRs, and I'm just wondering whether there's anything new in this PR?

@orozery
Contributor Author

orozery commented Sep 12, 2025

Is this PR the combination of #19848, #20075, #21448, and #22595? I've reviewed the above 4 PRs, and I'm just wondering whether there's anything new in this PR?

So each PR actually introduces a single new commit.
For this PR, the new commit is just the registration of the CPU implementation for the offloading connector.

Member

@njhill njhill left a comment


LGTM .. needs the block hash type change of course (and I assume that affects the other PRs too...)

Comment on lines 112 to 67
attn_backend = get_attn_backend(
    self.vllm_config.model_config.get_head_size(),
    self.vllm_config.model_config.dtype,
    self.vllm_config.cache_config.cache_dtype,
    self.gpu_block_size,
    self.vllm_config.model_config.is_attention_free,
    use_mla=self.vllm_config.model_config.use_mla)
Member

I feel like we should add a get_attn_backend_from_config(VllmConfig) and use that both here and in the NixlConnector.
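For reference, a rough sketch of the suggested helper (the name, and taking the block size from cache_config, are assumptions rather than existing vLLM API):

def get_attn_backend_from_config(vllm_config: VllmConfig):
    model_config = vllm_config.model_config
    return get_attn_backend(
        model_config.get_head_size(),
        model_config.dtype,
        vllm_config.cache_config.cache_dtype,
        vllm_config.cache_config.block_size,
        model_config.is_attention_free,
        use_mla=model_config.use_mla)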

Member

@njhill njhill left a comment


@orozery sorry, I think some of the comments here actually apply to an earlier commit

Comment on lines 53 to 57
# allocate fresh blocks
blocks: list[BlockStatus] = []
for _ in range(num_fresh_blocks):
    blocks.append(CPUBlockStatus(self.num_allocated_blocks))
    self.num_allocated_blocks += 1
Member

So there will only be "fresh" blocks temporarily until the cache is full?

Contributor Author

Right

Comment on lines 168 to 169
for blk_hash in request.block_hashes[self.block_size_factor -
                                     1::self.block_size_factor]
Member

use islice?

return 0, False

start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(block_hashes[start_block_idx:])
Member

use islice?

Comment on lines 224 to 231
block_hashes = [
    blk_hash.hash_value
    for blk_hash in request.block_hashes[self.block_size_factor -
                                         1::self.block_size_factor]
]
assert len(block_hashes) >= num_blocks

block_hashes = block_hashes[start_block_idx:num_blocks]
Member

Suggested change
block_hashes = [
    blk_hash.hash_value
    for blk_hash in request.block_hashes[self.block_size_factor -
                                         1::self.block_size_factor]
]
assert len(block_hashes) >= num_blocks
block_hashes = block_hashes[start_block_idx:num_blocks]
step = self.block_size_factor
block_hashes = [
    blk_hash.hash_value for blk_hash in itertools.islice(
        request.block_hashes,
        (start_block_idx + 1) * step - 1,
        (num_blocks + 1) * step - 1,
        step)
]

src_specs = self.manager.prepare_load(block_hashes)
dst_specs = [
    GPULoadStoreSpec(gpu_block_id)
    for gpu_block_id in block_ids[num_computed_gpu_blocks:]
Member

use islice?

Contributor Author

I'm changing GPULoadStoreSpec to construct a tensor of block IDs. It needs a list as input, so islice can't be used here.

Member

I may be misunderstanding but this comprehension is creating a list of GPULoadStoreSpec objects, which can't be used to create a tensor directly?

Member

oh sorry, I think I understand: you are saying this is n/a after the latest refactor

Contributor Author

I changed it to create a single GPULoadStoreSpec object, wrapping a tensor.
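For illustration only, a minimal sketch of the shape described above (the dataclass definition and field name are assumptions, not the PR's actual class):

import torch
from dataclasses import dataclass

@dataclass
class GPULoadStoreSpec:
    block_ids: torch.Tensor  # one spec wraps the IDs of all GPU blocks to transfer

block_ids = [3, 7, 9, 12]
num_computed_gpu_blocks = 1
dst_spec = GPULoadStoreSpec(
    torch.tensor(block_ids[num_computed_gpu_blocks:], dtype=torch.int64))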


pytorch-bot bot commented Sep 15, 2025

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

@orozery
Contributor Author

orozery commented Sep 15, 2025

@njhill I switched to using islice.
The downside is you get a one-time iterator.
Because I use it more than once, I need to create several instances.
So this makes the code a bit more complex.
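A quick illustration of the single-use behavior being described (plain Python, unrelated to the PR's code):

from itertools import islice

data = list(range(10))
view = islice(data, 4, None)
print(list(view))  # [4, 5, 6, 7, 8, 9]
print(list(view))  # [] - the iterator is exhausted, so each pass needs a fresh islice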

@njhill
Member

njhill commented Sep 15, 2025

@njhill I switched to using islice. The downside is you get a one-time iterator. Because I use it more than once, I need to create several instances. So this makes the code a bit more complex.

Thanks @orozery I guess I don't follow the downside / extra complexity. E.g.

hits = self.manager.lookup(block_hashes[start_block_idx:])

just becomes

hits = self.manager.lookup(islice(block_hashes, start_block_idx, None))

If/where the sliced list is iterated over multiple times then I agree there may be more to consider.

@orozery
Contributor Author

orozery commented Sep 16, 2025

@njhill I switched to using islice. The downside is you get a one-time iterator. Because I use it more than once, I need to create several instances. So this makes the code a bit more complex.

Thanks @orozery I guess I don't follow the downside / extra complexity. E.g.

hits = self.manager.lookup(block_hashes[start_block_idx:])

just becomes

hits = self.manager.lookup(islice(block_hashes, start_block_idx, None))

If/where the sliced list is iterated over multiple times then I agree there may be more to consider.

For example, this:

https://github.com/vllm-project/vllm/blob/94a0405bfaa32932cdb0e8362250f68587ebfa95/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py#L237-L251

And:

https://github.com/vllm-project/vllm/blob/94a0405bfaa32932cdb0e8362250f68587ebfa95/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py#L283-L303

Member
@njhill njhill left a comment

Argh sorry these comments were sitting in pending, just realized I didn't submit it yesterday

src_specs = self.manager.prepare_load(block_hashes)
dst_specs = [
    GPULoadStoreSpec(gpu_block_id)
    for gpu_block_id in block_ids[num_computed_gpu_blocks:]
Member

oh sorry, I think I understand: you are saying this is n/a after the latest refactor

@mergify mergify bot added the kv-connector label Sep 18, 2025
@njhill njhill changed the title v1: CPU offloading [KV offload][5/N] Add CPUOffloadingSpec Sep 18, 2025

mergify bot commented Sep 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 19, 2025
@mergify mergify bot added documentation Improvements or additions to documentation and removed needs-rebase labels Sep 21, 2025
This commit registers a new OffloadingSpec to add CPU offloading support
to the OffloadingConnector.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery
Contributor Author

orozery commented Sep 21, 2025

@njhill I've added a small e2e test + small example in the docs



@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
Member

Is there a way to add some assertions to the test such that it will fail if the offload is not working? Should probably also verify correctness in conjunction with this.

Contributor Author
@orozery orozery Sep 21, 2025

There is a correctness unit test for the transfer function (test_cpu_gpu.py).
There is also a correctness unit test checking that the offloading connector generates the correct transfer addresses for the GPU and the offloaded medium.
I don't know how we can test correctness e2e.

Currently the test here just checks that prompt generation does not crash when using CPU offloading.
It does not verify that any offloading actually occurs.

One way we could verify this is by adding a kv_events_config (like in test_kv_cache_events) and checking for KVEvents with the CPU medium.
I actually started coding that but found it a bit cumbersome, so I decided to defer it until we see whether others think it's worthwhile.

Another option is to verify that latency decreases when we're supposed to hit the CPU cache (after resetting the GPU prefix cache).
We can reduce the variance by, say, repeating this 100 times and verifying that the latency decreased in at least 70 of them.
This would actually be easy to implement (compared to the KVEvents test).
My concern is that even when repeating the test many (e.g. 100) times, it can still be flaky.
Your thoughts?
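For reference, a rough sketch of the latency-comparison idea (the fixtures, helper names, and thresholds are assumptions, not the PR's test):

import time

def _measure(llm, prompt, sampling_params):
    start = time.perf_counter()
    llm.generate(prompt, sampling_params)
    return time.perf_counter() - start

def test_cpu_cache_hit_is_faster(llm, long_prompt, sampling_params):
    cold = _measure(llm, long_prompt, sampling_params)  # first run: nothing cached yet
    attempts = 5
    wins = 0
    for _ in range(attempts):
        llm.reset_prefix_cache()  # drop the GPU prefix cache; the CPU offload cache remains
        wins += _measure(llm, long_prompt, sampling_params) < cold
    assert wins > attempts // 2  # tolerate some timing noise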

Member

Thanks @orozery, yes I was thinking perhaps at least some kind of latency comparison, but I agree timing tests are fragile / generally not a good idea. If the magnitude of the difference is large enough, perhaps it wouldn't need so many attempts, maybe just a handful?

Member

Merging this to ensure that it makes the release but we might want to think a bit more about the e2e CI tests.

Thanks again for all of your hard work @orozery!

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 22, 2025
@njhill njhill merged commit 8db2939 into vllm-project:main Sep 22, 2025
50 checks passed
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: charlifu <charlifu@amd.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@alew3

alew3 commented Oct 7, 2025

@orozery how can I debug if the offload connector is being used?

--kv-transfer-config '{"kv_connector":"OffloadingConnector", "kv_role":"kv_both", "kv_connector_extra_config":{"num_cpu_blocks": 5000}}'

@orozery
Contributor Author

orozery commented Oct 8, 2025

@orozery how can I debug if the offload connector is being used?

--kv-transfer-config '{"kv_connector":"OffloadingConnector", "kv_role":"kv_both", "kv_connector_extra_config":{"num_cpu_blocks": 5000}}'

You can check the vLLM logs, which will include logs from the connector.
Control the log level with e.g. VLLM_LOGGING_LEVEL=DEBUG.

gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: gaojc <1055866782@qq.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>

Labels

ci/build, documentation (Improvements or additions to documentation), kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1
