[Core] Support disaggregated prefill with Mooncake Transfer Engine #10884
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
Now working on an OSDI submission; will review after Dec 10.
This is a great demonstration of adopting Mooncake in the current disaggregation implementation. Could you share some benchmark data and best practices here? Transfer Engine's primary features, like broader protocol support and topology-aware path selection, would be beneficial in larger-scale clusters. I am just curious how Mooncake performs in the simple 1P1D case or in isomorphic environments.
Here are some preview Mooncake benchmark results on A10 with up to 2 RDMA NICs. I am currently having some trouble benchmarking

- Varying tp (input length = 1024, qps = 2, output length = 6): [benchmark chart]
- Varying qps (length = 1024, tp = 4, output length = 6): [benchmark chart]
- Varying input length (tp = 4, qps = 2, output length = 6): [benchmark chart]
For best practices, I believe there are none before XpYd is ready. But if you want to test the Mooncake Transfer Engine, you can follow the guidance doc to reproduce the results. In addition, we are coordinating resources to integrate some machines with more RDMA NICs and more advanced GPUs. The official benchmark results will be released in due time.
Using the latest Mooncake code, when I test with tp=1, num_rdma_nic=2, qps=2, input_len=200, output_len=100 on a single machine (one prefill instance and one decode instance), the transfer engine reports an error:

E1213 02:57:10.528410 5811 worker_pool.cpp:274] Worker: Process failed for slice (opcode: 0, source_addr: 0x7efdf3ffd010, length: 404, dest_addr: 140532604981264, local_nic: mlx5_1, peer_nic: 127.0.0.1:8149@mlx5_0, dest_rkey: 2105088, retry_cnt: 0): transport retry counter exceeded

With a single RDMA device (mlx5_0 or mlx5_1), it works fine.
@junna2016 I think these errors are reported from the underlying transfer engine. Please open an issue in the Mooncake repo and we will get someone to help identify the root cause.
Clean implementation! I left some comments PTAL.
Two resolved review threads (outdated) on vllm/distributed/kv_transfer/kv_connector/mooncake_connector.py.
if self.transport_thread is None:
    self.transport_thread = ThreadPoolExecutor(max_workers=1)
# Block until the receive completes on the single transport worker.
tensor = self.transport_thread.submit(self._recv_impl).result()
# A 1-element tensor holding NONE_INT is the sentinel for a None payload.
if tensor.numel() == 1 and tensor.item() == NONE_INT:
This part is a bit tricky --- tensor.item() can return a wrong value when I use vLLM's PyNccl to transmit the tensor, while print(tensor) shows the right value. You don't need to change any code here if you don't see any bug when stress testing this code (I have a stress test in tests/kv_transfer/test_send_recv.sh that you can try).
For now, MooncakePipe has not encountered similar problems during stress testing and benchmarking. However, we will keep this part in mind and make corresponding modifications and tests as PyNcclPipe changes.
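As context for the sentinel check in the snippet above, here is a minimal sketch of the convention: a 1-element tensor holding NONE_INT stands in for a None payload. The function names and the NONE_INT value below are illustrative assumptions, not vLLM's actual definitions.

import torch

NONE_INT = -1  # placeholder sentinel; vLLM defines its own constant for this

def encode_maybe_none(tensor):
    # Encode None as a 1-element sentinel tensor so it survives transport.
    if tensor is None:
        return torch.tensor([NONE_INT], dtype=torch.int64)
    return tensor

def decode_maybe_none(tensor):
    # A 1-element tensor equal to NONE_INT decodes back to None.
    if tensor.numel() == 1 and tensor.item() == NONE_INT:
        return None
    return tensor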
LGTM! Thank you for your contribution!
Hello, could you please provide the PCIe bus bandwidth of the benchmark machine? This will help us make a horizontal comparison. Currently, the machine I am using is configured with 8x A100-SXM4-80GB GPUs, and the bus bandwidth for each slot is 16 GT/s at x16 width.
@liweiqing1997 Hello. Here we use 8x GA102GL [A10]; the PCIe PHY link speed is 16 GT/s (x16).
Hello, may I ask if you're using GPUDirect RDMA? I want to know the transfer path for KVCache: is it VRAM->DRAM->DRAM->VRAM, or VRAM->VRAM?
@liweiqing1997
Hello, I tried MooncakeConnector in vLLM, choosing {"protocol": "tcp"} and {"kv_buffer_device": "cpu"}. When the prefill node needs to send a tensor, I move the tensor from the GPU to the CPU. But when the input tokens of a request exceed 800, there is an error on the decode node. It seems the error occurs where the K cache is received. I compared the received src pointer and the data length, and they are the same as what the prefill node transmits. I do not know whether the K cache transmitted over TCP is too large and my network is relatively slow, resulting in lost data during transmission.
@liuyumoye You should not set
@ShangmingCai Sorry, I don't understand why. By the way, I am also curious: if the initial tensor is on the GPU, where are the VRAM->DRAM and DRAM->VRAM steps in the transfer chain implemented?
That is correct. But
You can check the implementation of
I think this comment is because vLLM uses PyNccl for transmitting the KV cache, but this is not the case for Mooncake.
There is no transfer from VRAM to DRAM in
Nope, either
BTW, tensor objects carry device info. If you are not sure where the transfers happen, you can print the device info of the tensors; it might help.
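A quick illustration of that suggestion, using plain PyTorch (nothing vLLM-specific; the tensor shape is arbitrary):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back off-GPU
kv = torch.randn(2, 8, 16, device=device)  # a made-up KV block
print(kv.device)       # e.g. cuda:0 -> the tensor lives in VRAM

staged = kv.to("cpu")  # an explicit VRAM -> DRAM copy (when kv was on GPU)
print(staged.device)   # cpu -> now in DRAM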
We really appreciate @KuntaiDu for his remarkable work in supporting the disaggregated prefill feature in vLLM. Since PR #10502 has been merged, we have rebased and switched the Mooncake integration from PR #10728 to this PR.
This PR is related to #10727 and is a continuation of PR #10502; it uses Mooncake's Transfer Engine for KVCache transfer instead of NCCL.
Mooncake is a KVCache-centric disaggregated architecture for LLM serving. Transfer Engine is Mooncake's core component; see its documentation for the design and API list.
Compared with NCCL, Mooncake Transfer Engine offers additional features, such as support for more transport protocols and topology-aware path selection (as noted in the review discussion above).
Like the current implementation of PR #10502, there are two roles: KV provider (e.g. the prefill vLLM instance) and KV consumer (e.g. the decode vLLM instance). Two buffer operations connect them (a sketch follows this list):

- insert: insert a KV cache into a buffer, so that it can be transferred upon request
- drop_select: select a KV cache based on tokens, transfer the selected KV, and drop this KV out from the buffer

Both roles run on different machines.
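Below is a minimal sketch of those two operations, assuming a plain in-memory dict keyed by token sequences. The class name LookupBufferSketch and the signatures are illustrative assumptions; the real vLLM lookup buffer operates on tensors and performs the actual transfer.

from typing import Any, Dict, Optional

class LookupBufferSketch:
    """Illustrative only: models insert/drop_select semantics, not vLLM's API."""

    def __init__(self) -> None:
        self._buffer: Dict[tuple, Any] = {}

    def insert(self, tokens: tuple, kv_cache: Any) -> None:
        # Provider side: stash the KV cache so it can be served upon request.
        self._buffer[tokens] = kv_cache

    def drop_select(self, tokens: tuple) -> Optional[Any]:
        # Consumer side: select by tokens, return the KV, and drop it from the buffer.
        return self._buffer.pop(tokens, None)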
Integration guide: https://github.com/kvcache-ai/mooncake/blob/main/doc/en/vllm-integration-v0.2.md
Benchmark result: https://github.com/kvcache-ai/mooncake/blob/main/doc/en/vllm-benchmark-results-v0.2.md
More benchmark results will be added in the future.
Test files will be added to align with the future test CI pipeline for PR #10502.
CC list: @KuntaiDu @youkaichao @alogfans @stmatengss @james0zan