
Conversation

@LCAIZJ (Contributor) commented Jul 1, 2025

What this PR does / why we need it?

This PR adopts the Mooncake TransferEngine for KV cache registration and a pull_blocks-style disaggregated prefill implementation.

Does this PR introduce any user-facing change?

No

Dependencies

  1. CANN Dependencies
    Using the Mooncake TransferEngine with Ascend Transport requires CANN version 8.2.RC1 or higher (see details in Mooncake#502).

  2. vllm-ascend
    This PR depends on changes introduced by Disaggregate prefill for kv cache register style #950 (modifications to model_runner_v1) and [V0.9.1][Bugfix] Remove schedulre patch for disaggregated PD #1361 (updates to schedule), both of which have been merged into the v0.9.1-dev branch and are expected to land in main shortly.

How was this patch tested?

@jianzs jianzs self-assigned this Jul 1, 2025
@LCAIZJ LCAIZJ force-pushed the dev branch 3 times, most recently from f8d44b2 to 7977dac Compare July 3, 2025 08:00
@LCAIZJ (Contributor, Author) commented Jul 3, 2025

Change log

07/03: change transfer_sync_read to batch_transfer_sync_read
07/05: fix formatting and mypy issues
07/06: rename MooncakeConnectorV1

@jianzs (Collaborator) commented Jul 3, 2025

@LCAIZJ plz make ci happy

@zzy-ContiLearn (Contributor) replied:

@LCAIZJ plz make ci happy

god bless us

@LCAIZJ LCAIZJ force-pushed the dev branch 6 times, most recently from d6a48f0 to f8fbcf6 Compare July 4, 2025 15:42
codecov bot commented Jul 5, 2025

Codecov Report

❌ Patch coverage is 90.89470% with 115 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.04%. Comparing base (1a70564) to head (7337fc9).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| vllm_ascend/distributed/mooncake_connector.py | 83.86% | 86 Missing ⚠️ |
| tests/ut/kv_connector/test_mooncake_connector.py | 96.02% | 29 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1568      +/-   ##
==========================================
+ Coverage   76.35%   77.04%   +0.69%     
==========================================
  Files         117      120       +3     
  Lines       13371    14788    +1417     
==========================================
+ Hits        10209    11393    +1184     
- Misses       3162     3395     +233     
| Flag | Coverage Δ |
| --- | --- |
| unittests | 77.04% <90.89%> (+0.69%) ⬆️ |

Flags with carried forward coverage won't be shown.


@LCAIZJ (Contributor, Author) commented Jul 5, 2025

@jianzs CI is happy

github-actions bot commented Jul 7, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@LCAIZJ (Contributor, Author) commented Jul 9, 2025

solve conflicts

github-actions bot commented:
This pull request has conflicts, please resolve those before we can evaluate the pull request.


LCAIZJ added 5 commits August 13, 2025 21:24
Signed-off-by: leichao.lc <leichao139636@163.com>
@LCAIZJ (Contributor, Author) commented Aug 13, 2025

@wangxiyuan Apologies, the code review comments were lost due to a force-push. I've implemented the majority of your suggestions; the Mooncake version specifics need some elaboration.
The version issue with the Mooncake package is mainly related to hardware-specific compilation configurations. Currently, Mooncake requires adjusting compile-time parameters based on the target hardware platform. For instance, when using Ascend 910B/910C NPUs, you need to enable the Ascend transport module by setting the USE_ASCEND option: option(USE_ASCEND "Option to enable NPU support" ON). This configuration is set in the CMake build system (refer to common.cmake#L62).
The default compilation settings in the official Mooncake repository have USE_ASCEND disabled. Users targeting Ascend hardware must manually enable this option and build the package from source, following the instructions in the Ascend Transport Configuration Guide.

"prefill with num_computed_tokens == 0."
# Assume that the request's KV cache is already fully prefilled and
# can be fetched entirely from the prefill node.
count = max(len(request.prompt_token_ids) - 1, 0)
Collaborator:

This line looks strange; shouldn't the external computed tokens exclude the locally computed tokens? Or do you just not support prefix cache on the decode instance?

@jianzs (Collaborator) commented Aug 14, 2025:

Correct. For this implementation, we assume two things: the decode node runs without prefix cache enabled, and the prefill KV cache does not use block-size truncation.

@ganyi1996ppo (Collaborator) commented Aug 14, 2025:

Why? Prefix cache is almost a kind of free lunch; it should bring no pain in any scenario.

Collaborator:

Good suggestion. While we typically consider prefix cache and KV pool as a coupled solution, enabling GPU prefix cache alone could be a great option for decode nodes. We'll explore this approach.
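For reference, a minimal sketch of what this assumption implies on the connector's scheduler side; the helper name here is hypothetical and simply mirrors the `count = max(len(request.prompt_token_ids) - 1, 0)` line quoted above.

```python
# Minimal sketch, assuming no prefix cache on the decode node: every prompt
# token except the last is reported as computed remotely, so the decode worker
# pulls the entire prefill KV cache and only recomputes the final position.
def num_external_tokens(prompt_token_ids: list[int]) -> int:
    # Mirrors `count = max(len(request.prompt_token_ids) - 1, 0)` above.
    return max(len(prompt_token_ids) - 1, 0)
```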

if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
local_block_ids = (blocks.get_unhashed_block_ids()
if num_external_tokens > 0 else [])
Collaborator:

Just out of curiosity, is there any chance in disaggregated PD that a remotely computed request does not have num_external_tokens?

Contributor (Author):

You're correct. In our current implementation the decode node runs without prefix cache enabled, so the scenario where num_external_tokens == 0 does not occur. If there are remote_blocks and num_external_tokens == 0, we have a full prefix cache hit on the D worker, and we need to call send_notif in _read_blocks to free the memory on the P node.
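A hedged sketch of the edge case described above; `_read_blocks` and the notification helper follow the names in this discussion, while the class, signatures, and surrounding logic are assumptions rather than the PR's actual code.

```python
class MooncakeConnectorWorkerSketch:
    """Sketch only, not the PR's actual worker class."""

    def _read_blocks(self, local_block_ids, remote_block_ids,
                     remote_host, remote_port, request_id):
        if not local_block_ids:
            # Full prefix cache hit on the D worker: nothing to pull, but the
            # P node still holds the blocks and must be notified to free them.
            self._send_notif(remote_host, remote_port, request_id)
            return
        # ... otherwise issue the batched pull of the remote blocks ...

    def _send_notif(self, remote_host, remote_port, request_id):
        # Placeholder: send a "done" notification over the side channel.
        pass
```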

remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
last_token_id=request.output_token_ids[-1],
Collaborator:

So Mooncake only supports returning the first token from the prefiller's output? But how does the decode instance use this attribute? There seems to be no usage for it.

Collaborator:

Thank you for raising this concern. This is on me. We've made internal modifications to the API server so that the decoder appends the token ID to the original prompt token IDs upon receiving a request. We need to consider how to properly adapt this in the open-source version. Maybe we could have the decoder directly return the first token?

Collaborator:

Sounds good, I think it's worth a try.
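A hypothetical sketch of that decode-side adaptation; `last_token_id` is the field from the diff above, while the function and the shape of `kv_transfer_params` are assumptions for illustration only.

```python
# Sketch only: when the decode instance receives the KV-transfer params, it
# appends the prefiller's first generated token to the prompt so decoding can
# continue from it instead of regenerating it.
def build_decode_prompt(prompt_token_ids: list[int],
                        kv_transfer_params: dict) -> list[int]:
    last_token_id = kv_transfer_params.get("last_token_id")
    if last_token_id is None:
        return list(prompt_token_ids)
    return list(prompt_token_ids) + [last_token_id]
```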

block_len = (self.block_len[k % 2]
if self.use_mla else self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
Collaborator:

So you grouped the block IDs to make sure you pull blocks in a locality-friendly way? How much improvement does it get? Normally the block IDs should already be sequential in most scenarios; I just wonder whether the CPU overhead of the contiguous rearrangement can actually be beaten by its transfer benefits.

Collaborator:

Although we haven't measured the specific benefit of block grouping, our experiments show that blocks are often non-sequential. Since block sizes tend to be modest, it seems the computational overhead of the reordering should be negligible.
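A minimal sketch of the kind of grouping being discussed (hypothetical helper, not the connector's exact implementation):

```python
# Sketch only: coalesce runs of consecutive block IDs so each run can be
# fetched with one batched transfer instead of one request per block.
def group_consecutive(block_ids: list[int]) -> list[list[int]]:
    groups: list[list[int]] = []
    for block_id in block_ids:
        if groups and block_id == groups[-1][-1] + 1:
            groups[-1].append(block_id)
        else:
            groups.append([block_id])
    return groups

# e.g. group_consecutive([3, 4, 5, 9, 10]) -> [[3, 4, 5], [9, 10]]
```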

Signed-off-by: leichao.lc <leichao139636@163.com>
@Shichang-Zhang commented:

@LCAIZJ Hello, thanks for your contribution! Recently I'm trying to run vLLM v1 + Mooncake with the Mooncake connector on NPUs. I built an image of your PR with vllm-ascend v0.10.1 and Ascend Transport. I tried to run it on K8s and assigned 4 of the 8 910B NPUs to the pod:

        resources:
          limits:
            huawei.com/Ascend910: "4"
          requests:
            huawei.com/Ascend910: "4"

Then I entered the container and exported the environment variable ASCEND_RT_VISIBLE_DEVICES=2,3,4,5 (K8s assigns NPU resources randomly).
When I ran the command to start vLLM:

JSON_CONTENT="{\"local_hostname\": \"$POD_IP\", \"metadata_server\": \"redis://10.250.39.62:6379\",\"protocol\": \"tcp\",\"device_name\": \"\",\"master_server_address\": \"10.250.39.46:50001\",\"protocol\":\"ascend\"}"
          echo "$JSON_CONTENT" > /app/mooncake.json
          MOONCAKE_CONFIG_PATH=/app/mooncake.json VLLM_USE_V1=1 python3 -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct --port 8100 --tensor-parallel-size 2 --max-model-len 10000 --data-parallel-size 2 --data-parallel-address localhost --data-parallel-rpc-port 9100 --kv-transfer-config '{"kv_connector":"MooncakeConnectorV1","kv_role":"kv_producer","kv_buffer_device":"npu","kv_connector_module_path":"vllm_ascend.distributed.mooncake_connector","kv_parallel_size": 1,"kv_port": "20001","engine_id": "0","kv_rank": 0,"kv_connector_extra_config":{"prefill":{"tp_size":2,"dp_size":2},"decode":{"tp_size":2,"dp_size":2}}}'

I encountered this problem from torch_npu (v2.7.1) (screenshot: vllm-mooncake-torchnpu-bug1).
I think that inside the container torch_npu can only see the assigned NPU devices, with logical IDs in the range [0, 4), while the possible physical IDs in the environment variable ASCEND_RT_VISIBLE_DEVICES range over [0, 8).
But if the environment variable ASCEND_RT_VISIBLE_DEVICES is not set, the NPU device ID automatically becomes 0 during the Mooncake connector initialization stage, which is the wrong physical device ID and causes the process to try to connect to an invisible NPU device. So I wonder whether it should use the ranktable to get the physical-to-logical ID mapping instead of directly reading the environment variable ASCEND_RT_VISIBLE_DEVICES.

@zzy-ContiLearn (Contributor) replied:

(quoting @Shichang-Zhang's report above)

I also encountered the same issue. The device_id assigned by Kubernetes might start counting from 0 by default. In our connector, we currently read the device_id from the OS environment variables, which may raise an error when querying device_ip in the hccn_conf file. A possible solution is to add a new environment variable in vLLM (e.g., VLLM_DEVICE_ID) and then retrieve the device_id from vllm.envs.
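A hedged sketch of that fallback; `VLLM_DEVICE_ID` is the hypothetical variable proposed above, not an existing vLLM setting, and the function is illustrative rather than the connector's actual code.

```python
import os

# Sketch only: prefer an explicit device list, then ASCEND_RT_VISIBLE_DEVICES,
# then a dp/tp-derived default, mirroring the connector's current behaviour.
def resolve_device_id(dp_rank: int, tp_rank: int, tp_size: int) -> int:
    device_ids_str = (os.getenv("VLLM_DEVICE_ID")
                      or os.getenv("ASCEND_RT_VISIBLE_DEVICES"))
    if device_ids_str is None:
        device_ids = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
    else:
        device_ids = [int(d) for d in device_ids_str.split(",")]
    assert len(device_ids) > tp_rank
    return device_ids[tp_rank]
```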

Signed-off-by: leichao.lc <leichao139636@163.com>
@LCAIZJ (Contributor, Author) commented Aug 15, 2025

(quoting @Shichang-Zhang's report above)

Thank you for your question. In the Mooncake connector we identify NPUs through the ASCEND_RT_VISIBLE_DEVICES environment variable. For example, when using NPUs 2, 3, 4, and 5, your Kubernetes YAML should set ASCEND_RT_VISIBLE_DEVICES="2,3,4,5". If this variable is not configured, we assume the NPU indexing starts from 0 by default.
Our goal is to eliminate reliance on the ranktable. However, if the ranktable were still used, how would the physical NPU IDs be mapped in such a scenario?

@wangxiyuan (Collaborator) commented:

Let's consider building the Mooncake package into the Docker image as well. @wxsIcey

@wangxiyuan wangxiyuan merged commit 03ca2b2 into vllm-project:main Aug 18, 2025
18 of 20 checks passed
@wangxiyuan (Collaborator) commented Aug 18, 2025

The CI failed due to another known issue. Let's merge this first. Feel free to add a follow-up PR if there is any other issue.

@Shichang-Zhang commented:

@LCAIZJ @zzy-ContiLearn Thank you for your advice!
Firstly, I think we can only request a number of NPU devices (e.g. 4) for the pod and only learn which physical devices were actually assigned (e.g. devices 1, 5, 6, 7) inside the container, so I need to export the environment variable ASCEND_RT_VISIBLE_DEVICES in the container instead of configuring it in the deployment YAML.
Secondly, if the environment variable ASCEND_RT_VISIBLE_DEVICES is used, torch_npu hits the error shown previously, since torch_npu expects device IDs in the range [0, 4).
Currently I work around this with a new environment variable, e.g. VLLM_VISIBLE_DEVICES, to avoid the torch_npu error. I also had to modify the processing logic for this variable, because with multiple DP ranks every DP rank would otherwise take the first few elements (one per TP rank) of VLLM_VISIBLE_DEVICES. For example, if the variable is 1,5,6,7 and the config is dp2 tp2, dp0 tp0 gets device 1 and dp1 tp0 also gets device 1.
The patch file:
The patch file:

diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index e223877..2d84098 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -758,13 +758,15 @@ class MooncakeConnectorWorker:
         # get tp device id
         # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
         # introducing some changes
-        device_ids_str = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
+        device_ids_str = os.getenv("VLLM_VISIBLE_DEVICES", None)
         if device_ids_str is None:
             device_ids = list(
                 range(self.dp_rank * self.tp_size,
                       (self.dp_rank + 1) * self.tp_size))
         else:
-            device_ids = list(map(int, device_ids_str.split(',')))
+            #device_ids = list(map(int, device_ids_str.split(',')))
+            device_ids = list(map(int, device_ids_str.split(',')))[self.dp_rank * self.tp_size : (self.dp_rank + 1) * self.tp_size]
+        logger.info("device_ids: %s", device_ids)
         assert len(device_ids) > self.tp_rank  # type: ignore
         self.device_id = device_ids[self.tp_rank]  # type: ignore

@LCAIZJ (Contributor, Author) commented Aug 20, 2025

(quoting @Shichang-Zhang's comment and patch above)

We ran some tests in our local environment using vLLM-Ascend v0.9.1 (we did not use v0.10.1 due to potential dependency compatibility concerns) and were unable to reproduce the issue you mentioned.
In our setup, when the device list is [0, 1, 2, 3], we observed that:
For DP0, ASCEND_RT_VISIBLE_DEVICES is configured as [0, 1].
For DP1, ASCEND_RT_VISIBLE_DEVICES is configured as [2, 3].
After reviewing the relevant code in vLLM, we believe the current implementation should correctly partition devices across different data parallelism (DP) groups.

@Shichang-Zhang replied:

@LCAIZJ I used vllm-ascend v0.9.1 as well. In my setup, I run K8s on a physical machine with several vLLM (P or D) pods on it. As described in my first comment under this PR, the environment variable ASCEND_RT_VISIBLE_DEVICES is not compatible with torch_npu, because in each pod the torch_npu in the container can only observe the mounted NPU devices, not all NPU devices on the machine. When I avoid ASCEND_RT_VISIBLE_DEVICES by introducing a new environment variable, I need extra logic to pick the device for the current DP rank, as shown in the patch.

chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
### What this PR does / why we need it?
This PR adopts the Mooncake TransferEngine for KV cache registration and a
pull_blocks-style disaggregated prefill implementation.

### Does this PR introduce any user-facing change?
No

### Dependencies
1. CANN Dependencies
Using the Mooncake TransferEngine with Ascend Transport requires CANN
version 8.2.RC1 or higher (see details in
Mooncake [vllm-project#502](kvcache-ai/Mooncake#502)).

2. vllm-ascend
This PR depends on changes introduced by vllm-project#950 (modifications to
`model_runner_v1`) and vllm-project#1361 (updates to `schedule`), both of which have
been merged into the `v0.9.1-dev` branch and are expected to land in
`main` shortly.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@1c859a1

---------

Signed-off-by: leichao.lc <leichao139636@163.com>
Co-authored-by: jianzs <zheng.shoujian@outlook.com>
Co-authored-by: zzy-ContiLearn <1831242919@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
Co-authored-by: Dreamerleader <2270923832@qq.com>
Co-authored-by: chris668899 <15105191595@126.com>
Co-authored-by: Pz1116 <zpbzpb123123@gmail.com>
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
(same commit message and sign-offs as the referenced commit above)

Labels

documentation (Improvements or additions to documentation), module:tests

Projects

None yet


6 participants