fix: NIXL connector transfers partial block to pass full multi-modal context #21074
Conversation
Code Review
This pull request fixes an issue where partial KV blocks were not transferred in the NIXL connector, potentially leading to incomplete multi-modal context. The changes correctly adjust the logic to include partial blocks in the transfer, both on the producer and consumer side. The implementation looks correct and addresses the described problem. I've pointed out one critical issue regarding a potentially incorrect assertion that could lead to a crash in certain scenarios involving prefix caching.
@robertgshaw2-redhat I saw you wrote the initial implementation and I would love to know your opinion on changing the behavior from skipping the partial last block to transferring it.
Do you mean the encoder cache here? Or are you referring to the non-llava impl with cross blocks (mllama)?
@NickLucche I haven't considered the above for this change; can you elaborate or point me to the places where I can learn more about it? My understanding of multi-modal inference is that after the prefill stage, the context of the prompt (with the image template applied) and the image is captured in the KV blocks allocated for the request. Therefore, as long as the downstream (decode) worker can access all the blocks from the upstream (prefill) worker, it has sufficient information to continue the generation stage. This change ensures that all blocks are transferred, even if the last block is partially used, and it does let the decode worker produce a proper response.
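To make the full vs. partial block distinction concrete, here is a small illustrative sketch (not vLLM code; the block size and token count are made-up numbers):

```python
# Illustrative only: how many KV blocks a prompt occupies, and how many of
# them are completely full.
block_size = 16
num_prompt_tokens = 50  # made-up example

num_full_blocks = num_prompt_tokens // block_size       # 3 full blocks (48 tokens)
num_total_blocks = -(-num_prompt_tokens // block_size)  # 4 blocks total (ceil division)

# The decode worker needs access to all 4 blocks to have the complete
# context; the last block holds the remaining 2 tokens.
print(num_full_blocks, num_total_blocks)  # 3 4
```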
(force-pushed from 4d4e68b to eaadf83)
NickLucche
left a comment
Yes, after the encoder is run, the language_model part works exactly the same way.

> However, for decode worker, the multi-modal input will not be passed,
Had a hard time figuring out what you meant here.
I see now you mean you want to transfer the blocks regardless of whether a full page is reached because image tokens are blocking.
I am not sure this is the best default for regular non-MM models where this optimization was saving a few transfers.
We probably want to run a few benchmarks or just do this for mm.
Thanks for spotting this and for contributing!
Let me know what is the best way to proceed at this point. I do see that this optimization saves the last transfer in exchange for letting decode re-compute the last block of tokens. Yes, we should run a few benchmarks to understand the trade-off numerically. But for disaggregation, the ISL is typically long for it to be beneficial, which means there will already be a few full blocks to transfer, so adding the last partial block to the transfer may not be that bad.
After thinking a bit more, I think the point is for the decode worker to successfully reconstruct the same KV cache as what is in the prefill worker. The optimization works perfectly fine for non-MM models because the decode worker can re-compute the missing KV cache from the prompt, whereas in MM models the encoding stage can't be redone, unless the decode worker also receives the request with the MM input? But I am afraid that means repeating the MM input processing in the decode worker. Given that, I will still advocate for always transferring all KV cache from prefill, which better decouples the language model part from any processing before it (i.e. the encoding). Then, for disaggregation, the decode worker only needs the final prompt processed by prefill (from prefill_response.prompt) and the full KV cache, no matter whether the model is an MM model or not.
I agree, what you're saying makes a lot of sense for MM.
I had to do a similar thing last week to transfer all blocks, including the partial ones. Your implementation is, I think, incomplete. You need to turn this piece (vllm/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py, lines 255 to 257 in eaadf83) into this:

```python
len_prompt = len(request.prompt_token_ids)
num_computed_tokens = len_prompt - num_external_tokens
assert num_computed_tokens % self.block_size == 0
start_block = num_computed_tokens // self.block_size
local_block_ids = flatten_2d_lists(blocks.get_block_ids())[start_block:]
```

You'll also need to add this import at the top, of course:

```python
from vllm.utils import flatten_2d_lists
```
@hasB4K Do you mind helping me understand this better? Am I understanding it correctly that the issue is this: when the decode worker has previously processed a request with a shorter prompt (say 48 tokens and block size 32), the 2nd block will have a hash value and will not be returned by this function, and then when processing another request with the same shorter prompt as a prefix (say 72 tokens), the 2nd block will not be transferred?
I am also curious about your use case, is it also related to multi-modal?
Yes exactly 😉
No, I was just trying to optimize the last block transfer to avoid re-doing the prefill. But since you raised this issue with multi-modal, I'm pretty sure your PR is mandatory.
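For reference, a tiny sketch of the block arithmetic in the scenario above (illustrative only, using the token counts from the question):

```python
block_size = 32

def block_layout(prompt_len: int, block_size: int):
    """Return (num_full_blocks, tokens_in_partial_last_block)."""
    return prompt_len // block_size, prompt_len % block_size

print(block_layout(48, block_size))  # (1, 16): block 0 full, block 1 half full
print(block_layout(72, block_size))  # (2, 8):  blocks 0-1 full, block 2 partial
```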
@NickLucche I have gathered some simple runs with
Focusing on TTFT, as that is what sending the partial block or not affects the most. You can see that the numbers before and after the change are not significantly different.
After change (transfer all)
Before change (skip partial)
Hi @hasB4K, after going deeper into the implementation, I think the change you proposed may not be needed. Let me know if there is a misunderstanding in my reasoning, and it would be great if there is a simple test case I can experiment with to visualize what I have missed, if the change is indeed needed.
This pull request has merge conflicts that must be resolved before it can be merged.
…i-modal context to downstream worker Signed-off-by: GuanLuo <gluo@nvidia.com>
Signed-off-by: GuanLuo <gluo@nvidia.com>
Signed-off-by: GuanLuo <gluo@nvidia.com>
(force-pushed from eaadf83 to c9cbab6)
@NickLucche do you mind giving another round of review when you have a chance? It would be great if this could be merged this week, as it affects correctness and we want to include the fix in our next release.
I double-checked by running some tests by adding this:
And you are right, it seems that my implementation was not necessary, since it's always equal 😅.
NickLucche
left a comment
Hey @GuanLuo thanks for running the benchmarks and apologies for the late response! Are you running on IB-connected nodes btw?
Let's get a second opinion on this @robertgshaw2-redhat , otherwise this LGTM thanks for the work!
Yes
njhill
left a comment
Thanks @GuanLuo! It looks mostly good to me.
I think a change here may also be required for the intermediate-host-buffer case (used for TPU). cc @juncgu
Also is there any chance you could add a test / extend the existing nixl connector tests to cover this? I.e. that tokens from partial blocks are now transferred rather than recomputed on decode side?
@njhill I have updated the test cases to reflect the new behavior. I made a somewhat big change to test_cannot_schedule_after_recv.
Previously the test was testing the old behavior where the last partial block is not transferred, so after the transfer finishes, the immediate step (step 4) requests a new block (for the rest of the tokens) when looping through waiting requests and gets skipped due to insufficient blocks. In other words, the request stays in the waiting list. But with the new behavior, once the transfer is initiated, the request has allocated blocks for all tokens, so it is moved to the running list as soon as the transfer finishes. This introduces new steps where the request gets moved but can't proceed, as it needs a new block for the generated token, and gets moved back to the waiting list, so we have (waiting -> running -> waiting). Note that I changed the prompt size for the remote request as well so that the request follows the above flow; if the prompt size resulted in a partial block, the remote request could be scheduled immediately after the transfer, as the partial block can hold the generated tokens.
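To illustrate that last point, a minimal sketch (not the actual test code; the block size and prompt lengths are made-up):

```python
block_size = 16

def needs_new_block_for_first_decode_token(prompt_len: int) -> bool:
    # A prompt that exactly fills its blocks leaves no room for the first
    # generated token, so the scheduler must allocate a new block; a partial
    # last block still has free slots for it.
    return prompt_len % block_size == 0

print(needs_new_block_for_first_decode_token(48))  # True  -> may bounce waiting -> running -> waiting
print(needs_new_block_for_first_decode_token(50))  # False -> can keep running right after the transfer
```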
Signed-off-by: GuanLuo <gluo@nvidia.com>
NickLucche
left a comment
Left a few comments for minor stuff.
Re: test_cannot_schedule_after_recv, apologies if I misread your changes, but I think the test is now functionally different.
Can we either:
- a) increase tokens so that we're actually testing the not-enough-blocks case
- b) do a separate case with the last partial block now being transferred after your changes, but keep a) to test the old intended behavior?
Thanks for the great work here!
I think it makes sense to add (a) in addition to the current test cases; with the current change, the KV transfer will not be scheduled until the previous request is finished, due to insufficient blocks.
Signed-off-by: GuanLuo <gluo@nvidia.com>
Signed-off-by: GuanLuo <gluo@nvidia.com>
NickLucche
left a comment
Awesome!
I think the name could be a bit less general than test_cannot_recv but happy with the change :)
Awesome! @njhill @robertgshaw2-redhat can I get another round of review from you?
njhill
left a comment
Thanks @GuanLuo!
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com> Signed-off-by: Paul Pak <paulpak58@gmail.com>
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com> Signed-off-by: Diego-Castan <diego.castan@ibm.com>
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com>
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com>
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
…context (vllm-project#21074) Signed-off-by: GuanLuo <gluo@nvidia.com>
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.

Purpose
In the disaggregation case, the prefill worker receives the multi-modal input (say, an image), and the embedding is processed and stored in the KV cache. However, the multi-modal input is not passed to the decode worker, and in that case the decode worker relies on the KV cache transfer to obtain the multi-modal context.
The previous implementation only transferred full KV blocks, which may result in incomplete context when part of it sits in the last, partially filled block. This change simply always transfers all blocks to work around that.
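For illustration, a minimal sketch of the behavioral change (not the actual connector code; names such as `num_blocks_to_transfer` are hypothetical and the numbers are made up):

```python
import math

block_size = 16
num_prompt_tokens = 50

# Old behavior: only full blocks were transferred, leaving the trailing
# partial block (tokens 48-49 here) to be recomputed by the decode worker,
# which is not possible for multi-modal context.
old_num_blocks_to_transfer = num_prompt_tokens // block_size            # 3

# New behavior: all allocated blocks are transferred, including the partial
# last block, so the decode worker receives the complete context.
new_num_blocks_to_transfer = math.ceil(num_prompt_tokens / block_size)  # 4
```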
Test Plan
Test Result
(Optional) Documentation Update