From bc8a2687ef9eb388cd910856f0d41e6a7bdd354b Mon Sep 17 00:00:00 2001
From: ganyi
Date: Mon, 19 May 2025 11:47:29 +0800
Subject: [PATCH 1/2] draft of a3 llmdatadist connector

Signed-off-by: ganyi
---
 .github/workflows/vllm_ascend_test_pd.yaml | 4 +
 examples/disaggregate_prefill_v1/README.md | 234 ++++
 .../disaggregate_prefill_v1/gen_ranktable.py | 120 +++
 .../disaggregate_prefill_v1/gen_ranktable.sh | 80 ++
 .../disaggregate_prefill_v1/run_server.sh | 32 +
 .../toy_proxy_server.py | 261 ++++
 .../e2e/pd_disaggreate/run_edge_case_test.sh | 141 ++++
 tests/e2e/pd_disaggreate/test_edge_cases.py | 81 ++
 .../test_llmdatadist_connector.py | 42 +
 .../test_remote_decode_lifecycle.py | 123 +++
 .../test_remote_prefill_lifecycle.py | 242 ++++
 tests/ut/kv_connector/utils.py | 196 +++
 vllm_ascend/attention/attention_v1.py | 2 +-
 vllm_ascend/attention/mla_v1.py | 57 +-
 vllm_ascend/distributed/__init__.py | 5 +
 .../llmdatadist_c_mgr_connector.py | 789 ++++++++++++++++++
 vllm_ascend/envs.py | 17 +
 vllm_ascend/models/deepseek_v2.py | 12 +-
 vllm_ascend/ops/attention.py | 34 +-
 vllm_ascend/patch/__init__.py | 14 +-
 .../patch/platform/patch_common/__init__.py | 1 +
 .../platform/patch_common/patch_scheduler.py | 491 +++++
 vllm_ascend/soc_info.py | 14 +
 vllm_ascend/worker/model_runner_v1.py | 156 +++-
 vllm_ascend/worker/worker_v1.py | 3 +
 25 files changed, 3074 insertions(+), 77 deletions(-)
 create mode 100644 examples/disaggregate_prefill_v1/README.md
 create mode 100644 examples/disaggregate_prefill_v1/gen_ranktable.py
 create mode 100644 examples/disaggregate_prefill_v1/gen_ranktable.sh
 create mode 100644 examples/disaggregate_prefill_v1/run_server.sh
 create mode 100644 examples/disaggregate_prefill_v1/toy_proxy_server.py
 create mode 100644 tests/e2e/pd_disaggreate/run_edge_case_test.sh
 create mode 100644 tests/e2e/pd_disaggreate/test_edge_cases.py
 create mode 100644 tests/ut/kv_connector/test_llmdatadist_connector.py
 create mode 100644 tests/ut/kv_connector/test_remote_decode_lifecycle.py
 create mode 100644 tests/ut/kv_connector/test_remote_prefill_lifecycle.py
 create mode 100644 tests/ut/kv_connector/utils.py
 create mode 100644 vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
 create mode 100644 vllm_ascend/patch/platform/patch_common/patch_scheduler.py
 create mode 100644 vllm_ascend/soc_info.py

diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml
index 57d8cf2a60..015c88c170 100644
--- a/.github/workflows/vllm_ascend_test_pd.yaml
+++ b/.github/workflows/vllm_ascend_test_pd.yaml
@@ -103,3 +103,7 @@ jobs:
       - name: Run vllm-project/vllm-ascend PD Disaggregation test
         run: |
           pytest -sv tests/e2e/pd_disaggreate/test_pd_e2e.py
+
+      - name: Run vllm-project/vllm-ascend PD Disaggregation edge test
+        run: |
+          bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
\ No newline at end of file
diff --git a/examples/disaggregate_prefill_v1/README.md b/examples/disaggregate_prefill_v1/README.md
new file mode 100644
index 0000000000..f096dcd1ce
--- /dev/null
+++ b/examples/disaggregate_prefill_v1/README.md
@@ -0,0 +1,234 @@
+# Disaggregated Prefill-Decode Deployment Guide
+
+## Overview
+This demo document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, using 16 Ascend NPUs for the two prefill nodes (P1/P2) and 16 Ascend NPUs for the two decode nodes (D1/D2).
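+
+For reference, the whole deployment is driven by a rank table that lists every NPU in the cluster and its role. Below is a trimmed sketch of the `ranktable.json` layout produced by `gen_ranktable.py`; the IPs, device IDs and cluster IDs are placeholders, and the `super_pod_id`/`super_device_id` fields only appear on A3 SoCs.
+
+```json
+{
+    "version": "1.2",
+    "server_count": "4",
+    "prefill_device_list": [
+        {"server_id": "172.19.32.175", "device_id": "0", "device_ip": "<device ip from hccn_tool>", "cluster_id": "1"}
+    ],
+    "decode_device_list": [
+        {"server_id": "172.19.123.51", "device_id": "0", "device_ip": "<device ip from hccn_tool>", "cluster_id": "17"}
+    ],
+    "status": "completed"
+}
+```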
+ +## Prerequisites +- Ascend NPU environment with vLLM 0.9.1 installed +- Network interfaces configured for distributed communication (eg: eth0) +- Model weights located at `/data01/deepseek_r1_w8a8_zhw` + +## Rank table generation +The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. The following command generates a rank table for all nodes with 16 cards prefill and 16 cards decode: + +Run the following command on every node to generate the rank table: +```shell +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \ + --npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16 +``` +Rank table will generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json` + +## Start disaggregated vLLM-ascend service +Execution Sequence +- 4 configured node ip are: 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 +- Start Prefill on Node 1 (P1) +- Start Prefill on Node 2 (P2) +- Start Decode on Node 1 (D1) +- Start Decode on Node 2 (D2) +- Start proxy server on Node1 + +* Run prefill server P1 on first node +```shell +export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export VLLM_VERSION=0.9.1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}' +``` + +* Run prefill server P2 on second node +```shell +export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export VLLM_VERSION=0.9.1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + 
"kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}' +``` + +* Run decode server d1 on third node +```shell +export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export VLLM_VERSION=0.9.1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}' +``` + +* Run decode server d2 on last node +```shell +export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export VLLM_VERSION=0.9.1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}' +``` + +* Run proxy server on the first node +```shell +cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1 +python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002 +``` + +* Verification +Check service health using the proxy server endpoint: +```shell +curl http://localhost:1025/v1/completions \ + -H "Content-Type: 
application/json" \ + -d '{ + "model": "deepseek", + "prompt": "Who are you?", + "max_tokens": 100, + "temperature": 0 + }' +``` + +* Performance +Test performance with vllm benchmark +```shell +cd /vllm-workspace/vllm/benchmarks +python3 benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --random-input-len 4096 \ + --random-output-len 1536 \ + --num-prompts 256 \ + --ignore-eos \ + --model deepseek \ + --tokenizer /data01/deepseek_r1_w8a8_zhw \ + --host localhost \ + --port 8000 \ + --endpoint /v1/completions \ + --max-concurrency 4 \ + --request-rate 4 +``` \ No newline at end of file diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.py b/examples/disaggregate_prefill_v1/gen_ranktable.py new file mode 100644 index 0000000000..d170f3ba06 --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.py @@ -0,0 +1,120 @@ +import argparse +import json +import os + +import torch.distributed as dist + +from vllm_ascend.soc_info import NPUSocInfo + +parser = argparse.ArgumentParser( + description="Arguments of rank table generator", ) +parser.add_argument("--local-host", type=str, required=True, help="local ip") +parser.add_argument("--prefill-device-cnt", + type=int, + required=True, + help="number of prefill devices") +parser.add_argument("--decode-device-cnt", + type=int, + required=True, + help="number of decode devices") +args = parser.parse_args() +local_host = args.local_host +prefill_device_cnt = args.prefill_device_cnt +decode_device_cnt = args.decode_device_cnt + +print("enter py") + +hccn_tool_path = os.environ.get("HCCN_TOOL_PATH", + "/usr/local/Ascend/driver/tools/hccn_tool") +master_addr = os.environ.get("MASTER_ADDR") +master_port = os.environ.get("MASTER_PORT") +rank = os.environ.get("RANK") +local_rank = os.environ.get("LOCAL_RANK") +# This variable is set by torchrun, +# and is different from WORLD_SIZE in gen_rank_table.sh. 
+world_size = os.environ.get("WORLD_SIZE") +soc_info = NPUSocInfo() + + +def get_cmd_stdout(cmd): + import subprocess + return subprocess.run(cmd, capture_output=True, + shell=True).stdout.decode("utf-8").strip() + + +print(f"local_host: {local_host}") +print("gen ranktable.json") + +num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split( + ":")[1].strip() +num_cards = int(num_cards) +chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split( + "\n")[0].split(":")[1].strip() +chips_per_card = int(chips_per_card) + +# generate local device list for local rank 0, and gather it to all ranks +local_device_list: list[dict[str, str]] = list() +if local_rank == "0": + super_pod_id = "0" + for card_id in range(num_cards): + for chip_id in range(chips_per_card): + device_id = card_id * chips_per_card + chip_id + if soc_info.is_a3: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr" + ).split(":")[1].strip() + super_device_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID" + ).split(":")[1].strip() + super_pod_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"" + ).split(":")[1].strip() + else: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr" + ).split(":")[1].strip() + + device_info = { + "server_id": local_host, + "device_id": str(device_id), + "device_ip": str(device_ip), + } + if soc_info.is_a3: + device_info.update({ + "super_pod_id": str(super_pod_id), + "super_device_id": str(super_device_id) + }) + local_device_list.append(device_info) + +dist.init_process_group(backend=dist.Backend.GLOO) +global_device_list = [None] * dist.get_world_size() +dist.all_gather_object(global_device_list, local_device_list) +global_device_list = [ + device_info for device_list in global_device_list + for device_info in device_list # type: ignore[attr-defined] +] +cnt = 1 +for device_info in global_device_list: # type: ignore[assignment] + device_info["cluster_id"] = str(cnt) + cnt += 1 +assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \ +"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster" +ranktable = { + "version": + "1.2", + "server_count": + str(world_size), + "prefill_device_list": + global_device_list[:prefill_device_cnt], + "decode_device_list": + global_device_list[prefill_device_cnt:prefill_device_cnt + + decode_device_cnt], + "status": + "completed" +} + +if local_rank == '0': + with open("ranktable.json", "w") as f: + json.dump(ranktable, f, indent=4) + + print("gen ranktable.json done") diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.sh b/examples/disaggregate_prefill_v1/gen_ranktable.sh new file mode 100644 index 0000000000..516bea956c --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} + +NPUS_PER_NODE=8 +while [[ $# -gt 0 ]]; do + case "$1" in + --ips) + shift + # 收集所有后续参数直到遇到下一个选项或结束 + while [[ $# -gt 0 && ! 
"$1" == --* ]]; do + IPs+=("$1") + shift + done + ;; + --npus-per-node) + shift + NPUS_PER_NODE="$1" + shift + ;; + --network-card-name) + shift + NETWORK_CARD_NAME="$1" + shift + ;; + --prefill-device-cnt) + shift + PREFILL_DEVICE_CNT="$1" + shift + ;; + --decode-device-cnt) + shift + DECODE_DEVICE_CNT="$1" + shift + ;; + esac +done +LOCAL_HOSTS=($(hostname -I)) +LOCAL_HOST="127.0.0.1" +MASTER_ADDR=${IPs[0]} +MASTER_PORT=6657 +NNODES=${#IPs[@]} +NODE_RANK="8" +for i in "${!IPs[@]}"; do + ip="${IPs[$i]}" + for local_host in "${LOCAL_HOSTS[@]}"; do + if [[ "$local_host" == "$ip" ]]; then + LOCAL_HOST=$local_host + NODE_RANK=$i + break 2 + fi + done +done + +if [[ $NODE_RANK == "" ]];then + echo "[Error] para \"NODE_RANK\" must be defined" + exit 1 +fi + +WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES)) +RANKSTART=`expr $NPUS_PER_NODE \* $NODE_RANK` + +echo "========>param:" +echo "LOCAL_HOST": $LOCAL_HOST +echo "WORLD_SIZE: " $WORLD_SIZE +echo "RANKSTART": $RANKSTART +echo "NNODES": $NNODES +echo "NODE_RANK": $NODE_RANK +echo "===============" + +if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then + GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \ + --nproc_per_node 1 \ + --nnodes ${NNODES} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT +fi \ No newline at end of file diff --git a/examples/disaggregate_prefill_v1/run_server.sh b/examples/disaggregate_prefill_v1/run_server.sh new file mode 100644 index 0000000000..37cf6d3aee --- /dev/null +++ b/examples/disaggregate_prefill_v1/run_server.sh @@ -0,0 +1,32 @@ +export HCCL_IF_IP=141.61.39.117 +export GLOO_SOCKET_IFNAME="enp48s3u1u1" +export TP_SOCKET_IFNAME="enp48s3u1u1" +export HCCL_SOCKET_IFNAME="enp48s3u1u1" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table + +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 + +export VLLM_USE_V1=1 + +vllm serve model_path \ + --host 0.0.0.0 \ + --port 20002 \ + --tensor-parallel-size 1\ + --seed 1024 \ + --served-model-name dsv3 \ + --max-model-len 2000 \ + ---max-num-batched-tokens 2000 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": 0, + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3" + }' \ + --additional-config \ + '{"enable_graph_mode": "True"}'\ diff --git a/examples/disaggregate_prefill_v1/toy_proxy_server.py b/examples/disaggregate_prefill_v1/toy_proxy_server.py new file mode 100644 index 0000000000..4478073f74 --- /dev/null +++ b/examples/disaggregate_prefill_v1/toy_proxy_server.py @@ -0,0 +1,261 @@ +# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import itertools +import os +import uuid +from contextlib import asynccontextmanager + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. 
+ """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/e2e/pd_disaggreate/run_edge_case_test.sh b/tests/e2e/pd_disaggreate/run_edge_case_test.sh new file mode 100644 index 0000000000..2f5b1bc8e9 --- /dev/null +++ b/tests/e2e/pd_disaggreate/run_edge_case_test.sh @@ -0,0 +1,141 @@ +#!/bin/bash +export LCCL_DETERMINISTIC=1 +export HCCL_DETERMINISTIC=true +export CLOSE_MATMUL_K_SHIFT=1 +export VLLM_USE_V1=1 + +set -xe + +# Models to run +MODELS=( + "Qwen/Qwen2.5-0.5B-Instruct" +) + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Gen 
ranktable +RANKTABLE_PATH=${GIT_ROOT}/examples/disaggregate_prefill_v1/ranktable.json +if [ -f "$RANKTABLE_PATH" ]; then + rm "$RANKTABLE_PATH" +fi +cd ${GIT_ROOT}/examples/disaggregate_prefill_v1 +LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'` +bash gen_ranktable.sh --ips $LOCAL_HOST --network-card-name enp189s0f0 --prefill-device-cnt 1 --decode-device-cnt 1 +cd - +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="$RANKTABLE_PATH" + +# Waits for vLLM to start. +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/health > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." + pkill -f "vllm serve" || true + sleep 2 +} + +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" + + if [[ "$model_name" == *"deepseek"* ]]; then + extra_args="--trust-remote-code" + fi + + echo "$extra_args" +} + + +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Start prefill instance + PREFILL_PORT=8001 + + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \ + --port $PREFILL_PORT \ + --seed 1024 \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Start decode instance + DECODE_PORT=8002 + + # Build the command with or without model-specific args + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \ + --port $DECODE_PORT \ + --seed 1024 \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Wait for all instances to start + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance on port $PORT to start..." 
+ wait_for_server $DECODE_PORT + + # Build the command for the proxy server with all the hosts and ports + PROXY_PORT=8192 + PROXY_CMD="python ${GIT_ROOT}/examples/disaggregate_prefill_v1/toy_proxy_server.py --port $PROXY_PORT" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORT}" + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/e2e/pd_disaggreate/test_edge_cases.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done + +echo "All tests completed!" \ No newline at end of file diff --git a/tests/e2e/pd_disaggreate/test_edge_cases.py b/tests/e2e/pd_disaggreate/test_edge_cases.py new file mode 100644 index 0000000000..fe53ddc6db --- /dev/null +++ b/tests/e2e/pd_disaggreate/test_edge_cases.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# This code is from: https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +import os + +import openai + +PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_PORT = os.getenv("PROXY_PORT", None) + +if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: + raise ValueError( + "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + +LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 +PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 +SHORT_PROMPT = "Red Hat is " + + +def test_edge_cases(): + # Set the OpenAI API key and base URL + decode_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{DECODE_PORT}/v1", + ) + prefill_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PREFILL_PORT}/v1", + ) + proxy_client = openai.OpenAI( + api_key="MY_KEY", + base_url=f"http://localhost:{PROXY_PORT}/v1", + ) + + # Get the list of models + models = decode_client.models.list() + MODEL = models.data[0].id + + # (1) Check that we can handle a very short prompt, + # less than the length of the block size. + completion = proxy_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=SHORT_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"SMALL PROMPT: {proxy_response=}") + print(f"SMALL PROMPT: {prefill_response=}") + assert proxy_response == prefill_response + + # (2) Check that we can handle a full prefix cache + # hit on the D worker but not on the P worker. + # (2a): prime the D worker. 
+ completion = decode_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + decode_response = completion.choices[0].text + # (2b): send via the P/D setup + completion = proxy_client.completions.create(model=MODEL, + prompt=PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + print(f"FULL CACHE HIT: {proxy_response=}") + assert proxy_response == decode_response + + # (3) Check that we can handle a partial prefix cache + # hit on the D worker. + completion = proxy_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + proxy_response = completion.choices[0].text + completion = prefill_client.completions.create(model=MODEL, + prompt=LONG_PROMPT, + temperature=0) + prefill_response = completion.choices[0].text + print(f"PARTIAL CACHE HIT: {proxy_response=}") + assert proxy_response == prefill_response \ No newline at end of file diff --git a/tests/ut/kv_connector/test_llmdatadist_connector.py b/tests/ut/kv_connector/test_llmdatadist_connector.py new file mode 100644 index 0000000000..94650f43e9 --- /dev/null +++ b/tests/ut/kv_connector/test_llmdatadist_connector.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +from tests.ut.kv_connector.utils import (create_request, create_scheduler, + create_vllm_config) +from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \ + LLMDataDistCMgrConnectorMetadata + + +def test_basic_inferface(): + """Unit test for basic LLMDataDistCMgrConnector interface functionality.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata) + + assert len(kv_connector_metadata.requests) == 1 + assert request_id in kv_connector_metadata.requests + req_meta = kv_connector_metadata.requests[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]): + assert block_id == block.block_id diff --git a/tests/ut/kv_connector/test_remote_decode_lifecycle.py b/tests/ut/kv_connector/test_remote_decode_lifecycle.py new file mode 100644 index 0000000000..d321490f65 --- /dev/null +++ b/tests/ut/kv_connector/test_remote_decode_lifecycle.py @@ -0,0 +1,123 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/conftest.py +# + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from tests.ut.kv_connector.utils import (assert_scheduler_empty, + create_model_runner_output, + create_request, create_scheduler, + create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a Remote Decode request.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + + # Ensure the request is finished after 1 tokens. + assert request.is_finished() + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + output = engine_core_outputs[0].outputs[0] + assert output.finish_reason == FinishReason.LENGTH + assert output.kv_transfer_params is not None + + # Request freed in Scheduler and blocks should be freed + assert request_id in scheduler.finished_req_ids + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_prefix_cache_lifecycle(): + """Test that remote decode params still works with a prefix cache hit.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 3 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote_a = create_request(request_id=1, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote_a) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote_a], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + ##################### + # Actual Test: confirm we send all blocks. + + # Send the KV Transfer. + NUM_EXTERNAL_FULL_BLOCKS -= 1 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request_remote_b = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request_remote_b) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote_b]) + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco[0].outputs[0].kv_transfer_params + # Ensure we send all block ids, even if there is a cache hit. 
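+    # request_remote_b hits the local prefix cache populated by request_remote_a,
+    # but the consumer side holds none of that KV, so the transfer metadata must
+    # still cover every block.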
+ assert (len( + kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + + 1)) + + scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py new file mode 100644 index 0000000000..c8629d5993 --- /dev/null +++ b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py @@ -0,0 +1,242 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/conftest.py +# +import copy + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.request import FinishReason, RequestStatus + +from tests.ut.kv_connector.utils import (assert_scheduler_empty, + create_model_runner_output, + create_request, create_scheduler, + create_vllm_config) + + +def test_basic_lifecycle(): + """Test lifecycle of a remote prefill.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + START_FREE_BLOCK_QUEUE_SIZE = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): + # (1a): schedule() + scheduler_output = scheduler.schedule() + + # Nothing running and empty scheduler output. + assert len(scheduler.running) == 0 + assert len(scheduler_output.scheduled_new_reqs) == 0 + assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert len(scheduler_output.num_scheduled_tokens) == 0 + assert scheduler_output.total_num_scheduled_tokens == 0 + + # Req waiting for KVs with no computed/scheduled toks ... + assert len(scheduler.waiting) == 1 + assert request in scheduler.waiting + assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert (request.num_computed_tokens == 0) + + # ... but should have (uncached) blocks allocated to it. + block_pool = scheduler.kv_cache_manager.block_pool + assert (block_pool.free_block_queue.num_free_blocks < + START_FREE_BLOCK_QUEUE_SIZE) + assert len(block_pool.cached_block_hash_to_block) == 0 + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] + for block in blocks: + assert block._block_hash is None + + # (1b): forward() + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + + # (1c): update_from_output() + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert not engine_core_outputs or not engine_core_outputs[0].outputs + + # STEP (2): + # (2a): schedule(): nothing happens! 
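+    # The request is still WAITING_FOR_REMOTE_KVS, so no tokens get scheduled yet.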
+ scheduler_output = scheduler.schedule() + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 0 + + # (2b): forward(): request finishes recv. + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + + # (2c): update_from_output(): + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # STEP (3): + # (3a): schedule(): this should actually schedule. + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + + # Confirm the block are actually allocated. + num_hashed_blocks = 0 + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] + for block in blocks: + assert block.ref_cnt == 1 + num_hashed_blocks += (1 if block._block_hash is not None else 0) + assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS + + # Confirm the rest of the prompt is scheduled in this step. + scheduled_req = scheduler_output.scheduled_new_reqs[0] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] + num_computed_tokens = scheduled_req.num_computed_tokens + total_prompt_tokens = len(scheduled_req.prompt_token_ids) + assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + + # (3b): execute_model() + model_runner_output = create_model_runner_output([request]) + # (3c): update_from_output() + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) + + +def test_no_spurious_prefix_caching(): + """ + With P/D, blocks can be allocated but uncomputed for + multiple engine steps. This test confirms that we do + not accidentally have cache hits against uncomputed + blocks. + """ + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 and a half full external blocks. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + # Both of these requests have prompts like [1,1,1,1,1, ...] + request_remote = create_request( + request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + use_all_1s_for_prompt_tokens=True, + ) + + # Schedule the remote prefill request. This should not + # cause any blocks to be cached. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert len(scheduler.waiting) == 1 + + remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_remote.request_id] + + # Remote blocks should not be cached. + for block in remote_blocks: + assert block.ref_cnt == 1 + assert block._block_hash is None + + +def test_full_block_prompt(): + """Test that we handle a prompt that is the full block size.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) + + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_prefill=True) + + scheduler.add_request(request) + request_id = request.request_id + + # STEP (1): Initialize a recv. + scheduler_output = scheduler.schedule() + # All blocks should be allocated. + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # STEP (2): Recv. + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_recving = [request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.waiting) == 1 + assert (request_id in scheduler.finished_recving_kv_req_ids) + + # # STEP (3): Run as usual. + scheduler_output = scheduler.schedule() + + # We need to recompute the final token of the prompt to generate + # the first new token, so we should not have a new block. + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) + assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_TOKENS - 1) + assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + + model_runner_output = create_model_runner_output([request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # # Step (4): Hit EOS. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output([request], use_eos=True) + engine_core_outputs = scheduler.update_from_output(scheduler_output, + model_runner_output) + scheduler.schedule() + + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + output = outputs[0] + assert output.finish_reason == FinishReason.STOP + assert_scheduler_empty(scheduler) diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py new file mode 100644 index 0000000000..2ad3b17aeb --- /dev/null +++ b/tests/ut/kv_connector/utils.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +import os +from typing import Any, Optional + +import torch +from vllm import SamplingParams +from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, + ModelConfig, SchedulerConfig, VllmConfig) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 +os.environ["VLLM_USE_V1"] = "1" + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. 
+ assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler.finished_recving_kv_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 1024, + block_size: int = 128, +) -> VllmConfig: + """Initialize VllmConfig For Testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="LLMDataDistCMgrConnector", + kv_role="kv_both", + kv_connector_module_path= + "vllm_ascend.distributed.llmdatadist_connector_v1_a3") + return VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu")) + + +def create_scheduler( + vllm_config: VllmConfig, + num_blocks: int = 10000, +) -> Scheduler: + """Initialize Scheduler For Testing.""" + block_size = vllm_config.cache_config.block_size + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float16, + False)) + ], + ) + vllm_config.cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_request( + request_id: int, + num_tokens: int = 10, + max_tokens: int = 128, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, +) -> Request: + """Make dummy request for testing.""" + + kv_transfer_params: Optional[dict[str, Any]] = None + + if do_remote_decode: + assert not do_remote_prefill + kv_transfer_params = dict(do_remote_prefill=False, + do_remote_decode=True) + elif do_remote_prefill: + kv_transfer_params = dict(do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list( + 
range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + remote_tp_size=1) + + max_tokens = 1 if do_remote_decode else max_tokens + sampling_params = SamplingParams(max_tokens=max_tokens) + + if use_all_1s_for_prompt_tokens: + prompt_token_ids = [1] * num_tokens + else: + prompt_token_ids = [i * request_id for i in range(num_tokens)] + + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + multi_modal_inputs=None, + multi_modal_placeholders=None, + multi_modal_hashes=None, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def create_model_runner_output( + reqs: list[Request], + finished_sending: Optional[list[str]] = None, + finished_recving: Optional[list[str]] = None, + use_eos: bool = False, +) -> ModelRunnerOutput: + """Make dummy model runner output for testing.""" + + # Make request data. + req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} + + # Make sampled tokens. + sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token_ids = [[sampled_token] for _ in req_ids] + + # Make output data structure. + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + finished_sending=finished_sending, + finished_recving=finished_recving, + ) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e6a2376786..3a54a5c15d 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -275,7 +275,7 @@ def forward( # TODO: Remove this contiguous in the future. 
value = value.contiguous() - if kv_cache.numel() > 0: + if len(kv_cache) > 0: if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 21a57a09b0..67f3e8c733 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -657,12 +657,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _compute_prefill_context( self, query: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, ): + assert len(kv_c_and_k_pe_cache) > 1 prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse @@ -672,21 +673,22 @@ def _compute_prefill_context( q_nope = query[..., :self.qk_nope_head_dim] seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) - latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim - cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim] - cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:] + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] seq_len = torch.stack([seq_len1, seq_len2]) kv_c_normed = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, latent_kv_dim, dtype=query.dtype, device=query.device) k_pe = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, rope_dim, dtype=query.dtype, device=query.device) @@ -736,10 +738,11 @@ def _forward_prefill( query: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 num_tokens = query.size(0) attn_output = torch.empty(num_tokens, @@ -831,10 +834,6 @@ def _forward_prefill( num_kv_heads=self.num_heads, out=attn_output) attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" 
- ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) if attn_metadata.attn_state in [ @@ -1027,7 +1026,7 @@ def forward( hidden_states_or_q_c: torch.Tensor, # query in unified attn hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, + kv_cache: Tuple[torch.Tensor], attn_metadata: M, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -1146,8 +1145,12 @@ def forward( prefill_q_pe.contiguous(), prefill_k_pe, max_seq_len=attn_metadata.prefill.max_seq_lens) + + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" if self.torchair_graph_enabled: - if len(kv_cache) > 0 and kv_cache[0].numel( + if kv_cache[0].numel( ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: slots = attn_metadata.slot_mapping # NOTE: Separate the kv cache in advance to avoid OOM or other issues @@ -1157,16 +1160,15 @@ def forward( key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slots) - elif kv_cache.numel() > 0: - key = torch.cat([ - kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), - k_pe - ], - dim=2) - torch_npu._npu_reshape_and_cache_siso( - key=key, - key_cache=kv_cache, - slot_indices=attn_metadata.slot_mapping.flatten()) + else: + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) if has_prefill: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy @@ -1189,11 +1191,10 @@ def forward( decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - output_decode = self._forward_decode(decode_ql_nope, - decode_q_pe, - decode_k_nope, - decode_k_pe, kv_cache, - attn_metadata) + combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1) + output_decode = self._forward_decode( + decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, + combined_cache, attn_metadata) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 88c2f2199b..d7be705c2b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -25,3 +25,8 @@ KVConnectorFactory.register_connector( "AscendSimpleConnector", "vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "LLMDataDistCMgrConnector", + "vllm_ascend.distributed.llmdatadist_c_mgr_connector", + "LLMDataDistCMgrConnector") diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py new file mode 100644 index 0000000000..c97f973bb3 --- /dev/null +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -0,0 +1,789 @@ +import contextlib +import json +import math +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import llm_datadist # type: ignore[import-not-found] +import msgspec +import torch +import zmq +from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, + LLMException, LLMRole) +from vllm.config import KVTransferConfig, VllmConfig +from 
vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.forward_context import ForwardContext +from vllm.utils import get_ip, logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_ascend import envs +from vllm_ascend.soc_info import NPUSocInfo + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32 +} + + +class LLMDataDistCMgrAgentMetadata(msgspec.Struct): + super_pod_id: str + server_id: str + device_id: str + device_ip: str + super_device_id: str + cluster_id: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: str + engine_id: str + remote_tp_size: str + + +class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req(self, request_id: str, local_block_ids: list[int], + kv_transfer_params: dict[str, Any]): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_tp_size=kv_transfer_params["remote_tp_size"], + ) + + +class LLMDataDistCMgrConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[ + LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler( + vllm_config, self.engine_id) + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, 
block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches( + self, + kv_caches: dict[str, # type: ignore[override] + Tuple[torch.Tensor]]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + LLMDataDistCMgrConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata, **kwargs) -> None: + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + +class LLMDataDistCMgrConnectorScheduler(): + + def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.local_ip = get_ip() + # Can not retrieve the parallel config since it is not initialized. + self.local_dp_rank = None + self.tp_size = None + dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + + self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT + + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}" + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. 
+ return 0, False + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_externel_tokens: int): + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}" + ) + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port", "remote_tp_size")): + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning("" \ + f"Invalid KVTransferParams {params}, This request will be discard") + else: + assert num_externel_tokens == 0 + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = LLMDataDistCMgrConnectorMetadata() + + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params) + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + + params = request.kv_transfer_params + logger.debug( + "LLMDataDistCMgrConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # note: NIXL transfer the full block only, but I don't see any reason to do that, so here + # we just transfer any data that computed from prefill node + # note: there might be some issue on this, check it if there is any unexpected result + computed_block_ids = block_ids + # If prompt < block_size, no xfer so free blocks immediately. + + return False, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.local_ip, + remote_port=self.port, + remote_tp_size=str( + self.vllm_config.parallel_config.tensor_parallel_size), + ) + + +class LLMDataDistCMgrConnectorWorker(): + """ + Implementation of Worker side methods + """ + + def __init__(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None + logger.info("Initialize the LLMDataDistCMgrConnectorWorker") + # we assume the local node only contains dp and tp, and tp will not communicate inter-node. + # for any scenario beyond this scope, the functionality of this connector is not guaranteed. 
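To make the scheduler-side handshake above concrete: `request_finished()` on the prefill (kv_producer) side returns a small `kv_transfer_params` dict, and `get_num_new_matched_tokens()` / `update_state_after_alloc()` on the decode (kv_consumer) side act on it. A minimal sketch with illustrative values (host, port and block ids are placeholders, not taken from a real deployment):

```python
# Illustrative only; field names follow the connector code above.
# Prefill side: what request_finished() hands back once a request reaches
# FINISHED_LENGTH_CAPPED with do_remote_decode set.
kv_transfer_params = {
    "do_remote_prefill": True,            # tells the decode node to pull KV
    "do_remote_decode": False,
    "remote_block_ids": [0, 1, 2, 3],     # blocks computed on the prefill node
    "remote_engine_id": "0",
    "remote_host": "192.0.2.10",          # prefill node IP (placeholder)
    "remote_port": 5557,                  # derived from VLLM_LLMDD_RPC_PORT
    "remote_tp_size": "8",
}

# Decode side: the scheduler only queues an async KV pull when the request
# asks for remote prefill and all remote_* fields are present.
required = ("remote_block_ids", "remote_engine_id", "remote_host",
            "remote_port", "remote_tp_size")
if kv_transfer_params.get("do_remote_prefill") and all(
        key in kv_transfer_params for key in required):
    print("request is parked in WAITING_FOR_REMOTE_KVS until the pull finishes")
```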
+ self.local_rank_on_node = get_world_group().rank % ( + vllm_config.parallel_config.data_parallel_size_local * + vllm_config.parallel_config.tensor_parallel_size) + self.local_rank = get_world_group().local_rank + self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_rank = get_tp_group().rank_in_group + self.rank = get_world_group().rank + self.local_ip = get_ip() + self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config + self.local_agent_metadata: Optional[ + LLMDataDistCMgrAgentMetadata] = None + self.vllm_config = vllm_config + + self.llm_datadist_role = None + self.llm_datadist_remote_role = None + if self.kv_transfer_config.kv_role == "kv_producer": + self.llm_datadist_role = LLMRole.PROMPT + self.llm_datadist_remote_role = LLMRole.DECODER + elif self.kv_transfer_config.kv_role == "kv_consumer": + self.llm_datadist_role = LLMRole.DECODER + self.llm_datadist_remote_role = LLMRole.PROMPT + else: + raise RuntimeError( + f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}" + ) + + # linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"} + self.linked_cluster: dict[Any, Any] = {} + self.prefill_device_list: list[tuple[int, int]] = [] + self.decode_device_list: list[tuple[int, int]] = [] + global_rank_table = self.read_offline_rank_table() + self.local_agent_metadata = self.read_agent_metadata( + global_rank_table, self.local_ip, self.local_rank_on_node, + self.llm_datadist_role) + self.llm_datadist = LLMDataDist(self.llm_datadist_role, + self.local_agent_metadata.cluster_id) + self.init_llm_datadist() + self.finished_reqs: set[str] = set() + self.soc_info = NPUSocInfo() + + def listen_for_agent_metadata_req(self, event: threading.Event): + assert self.local_agent_metadata is not None + port = envs.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank + url = f"tcp://0.0.0.0:{port}" + msg_encoder = msgspec.msgpack.Encoder() + msg_decoder = msgspec.msgpack.Decoder() + msg_to_send = msg_encoder.encode(self.local_agent_metadata) + logger.debug(f"Start to listen to address: {url}") + logger.debug( + f"The local agent metadata have {len(msg_to_send)} bytes here") + logger.info( + f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers" + ) + with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined] + event.set() + while True: + identity, _, msg = sock.recv_multipart() + decode_msg = msg_decoder.decode(msg) + if "cluster_id" in decode_msg: + decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) + logger.info( + f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" + ) + sock.send_multipart((identity, b"", msg_to_send)) + self.add_remote_agent(decode_msg) + else: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" + ) + + def init_llm_datadist(self): + assert self.local_agent_metadata is not None + llm_config = LLMConfig() + llm_config.device_id = self.local_rank + llm_config.sync_kv_timeout = 20000 + llm_config.enable_switch_role = True + llm_config.enable_cache_manager = True + llm_config.enable_remote_cache_accessible = True + 
llm_config_options = llm_config.generate_options() + self.llm_datadist.init(llm_config_options) + self.cache_manager = self.llm_datadist.cache_manager + logger.info( + f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}" + ) + + def read_offline_rank_table(self): + assert ( + envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + ), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH" + rank_table_path = envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + with open(rank_table_path, "r", encoding="utf-8") as f: + global_rank_table = json.load(f) + decode_device_list = global_rank_table["decode_device_list"] + for decode_device in decode_device_list: + server_id = decode_device["server_id"] + device_id = decode_device["device_id"] + self.decode_device_list.append((server_id, device_id)) + prefill_device_list = global_rank_table["prefill_device_list"] + for prefill_device in prefill_device_list: + server_id = prefill_device["server_id"] + device_id = prefill_device["device_id"] + self.prefill_device_list.append((server_id, device_id)) + + # global_rank_table = json.dumps(global_rank_table) + return global_rank_table + + def read_agent_metadata(self, global_rank_table, server_id, device_rank, + agent_role): + devices_type_list = [] + agent_metadata = None + if self.llm_datadist_role == LLMRole.PROMPT: + devices_type_list.append("prefill_device_list") + elif self.llm_datadist_role == LLMRole.DECODER: + devices_type_list.append("decode_device_list") + else: + devices_type_list.append("prefill_device_list") + devices_type_list.append("decode_device_list") + for device_type in devices_type_list: + device_list = global_rank_table[device_type] + device_list = [ + d for d in device_list if d.get("server_id") == server_id + ] + if len(device_list) <= device_rank: + continue + device_info = device_list[device_rank] + super_pod_id_ = device_info.get("super_pod_id", None) + server_id_ = device_info["server_id"] + device_id_ = device_info["device_id"] + device_ip_ = device_info["device_ip"] + super_device_id_ = device_info.get("super_device_id", None) + cluster_id_ = int(device_info["cluster_id"]) + agent_metadata = LLMDataDistCMgrAgentMetadata( + super_pod_id=super_pod_id_, + server_id=server_id_, + device_id=device_id_, + device_ip=device_ip_, + super_device_id=super_device_id_, + cluster_id=cluster_id_, + ) + assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table" + return agent_metadata + + def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + assert len(first_kv_cache_tuple) > 1 + assert self.local_agent_metadata is not None + kv_cache_dtype = first_kv_cache.dtype + self.use_mla: bool = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # MHA case. [2 (k and v), num_blocks, ...] 
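For reference, the rank table consumed by `read_offline_rank_table()` / `read_agent_metadata()` above is a plain JSON file keyed by `prefill_device_list` and `decode_device_list`, with one entry per NPU. The sketch below writes a minimal illustrative table (all IPs, device ids and cluster ids are placeholders; in practice the file is produced by `gen_ranktable.sh`):

```python
import json

# Illustrative rank table with one prefill and one decode device.
# Field names follow read_agent_metadata() above; all values are placeholders.
rank_table = {
    "prefill_device_list": [{
        "server_id": "192.0.2.10",    # host IP of the prefill node
        "device_id": "0",             # NPU index on that node
        "device_ip": "29.10.0.1",     # device-side (HCCN) IP
        "cluster_id": "1",
        "super_pod_id": "0",          # optional, only used on A3 SoCs
        "super_device_id": "100",     # optional, only used on A3 SoCs
    }],
    "decode_device_list": [{
        "server_id": "192.0.2.11",
        "device_id": "0",
        "device_ip": "29.10.0.2",
        "cluster_id": "2",
        "super_pod_id": "0",
        "super_device_id": "200",
    }],
}

with open("ranktable.json", "w", encoding="utf-8") as f:
    json.dump(rank_table, f, indent=2)
```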
+ self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + self.block_len = math.prod(block_shape) + self.cache_addr: list[int] = [] + alignment = 2 * 1024 * 1024 + if self.use_mla: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + k_normed = None + k_pe = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe = cache_or_caches[0], cache_or_caches[1] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe) + self.cache_key = (cache_key_k_normed, cache_key_k_pe) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + self.cache = (cache_k_normed, cache_k_pe) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + else: + for cache_or_caches in kv_caches.values(): + for cache in cache_or_caches: + base_addr = cache.data_ptr() + assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + self.cache_addr.append(base_addr) + # register paged kv cache into the llm_cache manager + self.cache_desc = CacheDesc( + len(self.cache_addr), [*cache.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + self.cache_key = BlocksCacheKey( + cluster_id=int(self.local_agent_metadata.cluster_id)) + logger.info( + f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}" + ) + try: + self.cache = self.cache_manager.register_blocks_cache( + self.cache_desc, self.cache_addr, self.cache_key) + logger.info( + "LLMDataDistCMgrConnectorWorker: End of register Paged Cache." 
+ ) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + self.ready_event = threading.Event() + self.metadata_agent_listener_t = threading.Thread( + target=self.listen_for_agent_metadata_req, + args=(self.ready_event, ), + daemon=True, + name="metadata_agent_listener") + self.metadata_agent_listener_t.start() + self.ready_event.wait() + + def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata): + for req_id, meta in metadata.requests.items(): + logger.debug(f"Start to transmit {req_id}") + self._read_blocks(meta.local_block_ids, + meta.remote_block_ids, meta.remote_host, + int(meta.remote_port), meta.engine_id, req_id, + meta.remote_tp_size) + self.finished_reqs.add(req_id) + + def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: + assert self.local_agent_metadata is not None + remote_cluster_id = metadata.cluster_id + if remote_cluster_id in self.linked_cluster: + logger.debug( + f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection" + ) + return remote_cluster_id + remote_super_pod_id = metadata.super_pod_id + remote_server_id = metadata.server_id + is_same_server = remote_server_id == self.local_agent_metadata.server_id + is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id + if self.llm_datadist_role == LLMRole.PROMPT: + prefill_metadata = self.local_agent_metadata + decode_metadata = metadata + else: + prefill_metadata = metadata + decode_metadata = self.local_agent_metadata + comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}" + cluster_rank_info = { + prefill_metadata.cluster_id: 0, + decode_metadata.cluster_id: 1 + } + rank_table = {} + rank_table["version"] = "1.2" + rank_table["server_count"] = "1" if is_same_server else "2" + rank_table["status"] = "completed" + + # generate server_list for rank table + rank_table["server_list"] = [] # type: ignore[assignment] + decode_server_device_info = None + prefill_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", prefill_metadata.device_id + ), ("device_ip", prefill_metadata.device_ip + ), ("super_device_id", + prefill_metadata.super_device_id), ("rank_id", "0")] + if v is not None + }], + "server_id": + prefill_metadata.server_id + } + if is_same_server: + prefill_server_device_info["device"].append( # type: ignore[attr-defined] + { + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }) + else: + decode_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }], + "server_id": + decode_metadata.server_id + } + rank_table["server_list"].append( # type: ignore[attr-defined] + prefill_server_device_info) + if decode_server_device_info is not None: + rank_table["server_list"].append( # type: ignore[attr-defined] + decode_server_device_info) + + if self.soc_info.is_a3: + # generate super_pod_list for rank table + super_pod_list = [] + prefill_super_pod_info = { + "super_pod_id": prefill_metadata.super_pod_id, + 
"server_list": [{ + "server_id": prefill_metadata.server_id + }], + } + if is_same_pod and not is_same_server: + prefill_super_pod_info[ + "server_list"].append( # type: ignore[attr-defined] + {"server_id": decode_metadata.server_id}) + super_pod_list.append(prefill_super_pod_info) + if not is_same_pod: + decode_super_pod_id = { + "super_pod_id": decode_metadata.super_pod_id, + "server_list": [{ + "server_id": decode_metadata.server_id + }], + } + super_pod_list.append(decode_super_pod_id) + rank_table[ + "super_pod_list"] = super_pod_list # type: ignore[assignment] + logger.info( + f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}" + ) + logger.info(f"rank table \n{rank_table}") + logger.info(f"comm name: {comm_name}") + logger.info(f"cluster rank info: {cluster_rank_info}") + comm_id = self.llm_datadist.link(comm_name, cluster_rank_info, + json.dumps(rank_table)) + while True: + ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id) + if ret == llm_datadist.RegisterMemStatus.OK: + logger.info( + f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}" + ) + break + elif ret == llm_datadist.RegisterMemStatus.FAILED: + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}" + ) + time.sleep(1) + logger.info("Checking query_register_mem_status again") + self.linked_cluster.update({remote_cluster_id: comm_id}) + logger.info(f"cached linked cluster: {self.linked_cluster}") + logger.info( + f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !" + ) + return remote_cluster_id + + def remove_remote_agent(self, cluster_id: int): + if cluster_id not in self.linked_cluster: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list" + ) + comm_id = self.linked_cluster[cluster_id] + try: + self.llm_datadist.unlink(comm_id) + self.linked_cluster.pop(cluster_id) + except LLMException: + logger.error( + f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment" + ) + logger.info( + f"Successfully remove remote client with cluster id {cluster_id} !" 
+ ) + + def connect_to_remote_agent(self, host: str, port: int) -> int: + url = f"tcp://{host}:{port}" + logger.debug(f"Querying metadata from url: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode(self.local_agent_metadata) + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + logger.info("Try request remote metadata from socket......") + sock.send(msg_send) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder() + metadata = decoder.decode(metadata_bytes) + metadata = LLMDataDistCMgrAgentMetadata(**metadata) + logger.info(f"recving metadata: {metadata}") + cluster_id = self.add_remote_agent(metadata) + return cluster_id + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_ip: str, + remote_port: int, + remote_engine_id: str, + request_id: str, + remote_tp_size: str, + ): + # if remote_ip not in self.linked_cluster: + tp_offset = self.tp_rank % int(remote_tp_size) + remote_cluster_id = self.connect_to_remote_agent( + remote_ip, remote_port + tp_offset) + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + return + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + logger.info(f"remote cluster id is: {remote_cluster_id}") + if self.use_mla: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + else: + remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key, + self.cache, # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requuests.""" + import copy + req_ids_to_ret = copy.deepcopy(self.finished_reqs) + self.finished_reqs.clear() + if self.llm_datadist_role == LLMRole.PROMPT: + return 
req_ids_to_ret, None + else: + return None, req_ids_to_ret + + +# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None # type: ignore[name-defined] + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: # type: ignore[attr-defined] + socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined] + socket.bind(addr) + elif socket_type == zmq.REQ: # type: ignore[attr-defined] + socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined] + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 02ecd6625b..fd5d6c681b 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -116,6 +116,23 @@ # value to False to disable the optimized model. "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), + # `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is + # used for llmdatadist to build the communication topology for kv cache transfer, it is + # a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated + # pd. The rank table can be generated by adopting the script `gen_ranktable.sh` + # in vllm_ascend's example folder. + "DISAGGREGATED_PREFILL_RANK_TABLE_PATH": + lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None), + # `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_IP` is used as the + # rpc communication listening ip, which will be used to receive the agent metadata from the + # remote worker. + "VLLM_ASCEND_LLMDD_RPC_IP": + lambda: os.getenv("VLLM_ASCEND_LLMDD_RPC_IP", "0.0.0.0"), + # `LLMDataDistCMgrConnector` required variable. `VLLM_LLMDD_RPC_PORT` is used as the + # rpc communication listening port, which will be used to receive the agent metadata from the + # remote worker. 
+ "VLLM_LLMDD_RPC_PORT": + lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)) } # end-env-vars-definition diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index e96b2e9847..f8bf83b1e1 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -33,7 +33,8 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, get_tp_group) @@ -286,6 +287,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group self.ep_group = get_ep_group() + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role == "kv_consumer" self.params_dtype = torch.get_default_dtype() @@ -307,6 +312,11 @@ def forward( enable_force_load_balance = False if hasattr(attn_metadata, 'with_prefill_across_dp'): is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer is not None: + is_prefill = False + enable_force_load_balance = False # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py index 8037c9545b..05600aee7a 100644 --- a/vllm_ascend/ops/attention.py +++ b/vllm_ascend/ops/attention.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
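Tying the RPC variables defined in `envs.py` above back to the connector: the scheduler advertises a base port derived from `VLLM_LLMDD_RPC_PORT` and the local DP rank, each worker's metadata listener adds its TP rank to that base, and the decode-side worker adds `tp_rank % remote_tp_size` when it connects. A small worked example with assumed parallel sizes (illustrative only):

```python
# Illustrative only: how the listener ports line up, assuming
# VLLM_LLMDD_RPC_PORT=5557, tensor_parallel_size=8 and DP rank 0 on this node.
BASE_PORT = 5557
TP_SIZE = 8
DP_RANK_LOCAL = 0

# Scheduler side (request_finished): the port advertised to the decode node.
advertised_port = DP_RANK_LOCAL * TP_SIZE + BASE_PORT              # 5557

# Worker side (listen_for_agent_metadata_req): each TP rank listens on
# the advertised port plus its own tp_rank.
listener_ports = [advertised_port + tp_rank for tp_rank in range(TP_SIZE)]
# -> [5557, 5558, ..., 5564]

# Decode side (_read_blocks): TP rank r connects to
# remote_port + (r % remote_tp_size), which lands on the matching listener.
decode_tp_rank, remote_tp_size = 3, 8
target_port = advertised_port + (decode_tp_rank % remote_tp_size)  # 5560
assert target_port in listener_ports
```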
-from typing import List, Optional +from typing import List, Optional, Tuple import torch from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -37,7 +37,7 @@ def vanilla_chunked_prefill( scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool = True, -) -> None: +) -> torch.Tensor: num_query_heads = query.shape[1] head_dim = value_cache.shape[3] num_kv_heads = value_cache.shape[2] @@ -138,7 +138,8 @@ def vanilla_chunked_prefill( def vanilla_chunked_prefill_mla( output: torch.Tensor, # (num_tokens, num_heads, v_head_dim) query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim) - kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv) + kv_cache: Tuple[ + torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv) block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq) query_lens: torch.Tensor, # (batch_size) context_lens: torch.Tensor, # (batch_size) @@ -152,22 +153,25 @@ def vanilla_chunked_prefill_mla( alibi_slopes: Optional[torch.Tensor], causal: bool = True) -> None: batch_size = block_tables.size(0) + assert len(kv_cache) > 1 assert query_lens.size(0) == batch_size num_heads = query.size(1) - block_size = kv_cache.size(1) - latent_kv_dim = kv_cache.size(3) - rope_dim + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + block_size = nope_cache.size(1) + latent_kv_dim = nope_cache.size(-1) max_num_blocks_per_seq = block_tables.size(1) batch_size = query_lens.size(0) - kv_cache = kv_cache.squeeze() - # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] - cache_kv_c_pe = kv_cache[block_tables].view( - batch_size, max_num_blocks_per_seq * block_size, - latent_kv_dim + rope_dim)[:, :max_context_len, :] - # get kv_c and k_pe + nope_cache = nope_cache.squeeze() + # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe # cached_kv_c: [batch_size, max_context_len, latent_kv] # cached_k_pe: [batch_size, max_context_len, rope_dim] - cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim] - cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:] + cache_kv_c = nope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + latent_kv_dim)[:, :max_context_len, :] + cache_k_pe = rope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + rope_dim)[:, :max_context_len, :] # get k_rope and v # k_nope: [batch_size, max_context_len, num_heads, nope_dim] # value: [batch_size, max_context_len, num_heads, v_head_dim] @@ -258,8 +262,8 @@ def vanilla_chunked_prefill_mla( attn_output = (attn_output[q_mask].view([-1, num_heads, v_head_dim]).to(output.dtype)) - output = output.view([-1, num_heads, v_head_dim]) - output.copy_(attn_output[:query.size(0) - num_add_query]) + attn_output = attn_output.view_as(output) + output.copy_(attn_output) return attn_output diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 3c24bfc70f..7d7da6328f 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -74,7 +74,19 @@ # Related PR (if no, explain why): # Need a PR to vllm to support more backend. # Future Plan: -# Remove those patch when vllm support more backend. +# Remove those patch when vllm merged them +# +# ** File: platform/patch_common/patch_scheduler.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.core.sched.scheduler.Scheduler.destroy_model_parallel()` +# Why: +# Vllm transfer the kv blocks data only when this block have already been full filled. 
However, this behaviour may cause decode node +# exist prefill behaviour. In order to make decode node work as expected, we always transfer all data whether or not the block is filled. +# How: +# The num_computed_token shall always equals to the token number of request during scheduling. +# Related PR (if no, explain why): https://github.com/vllm-project/vllm/pull/17751 (nixl implementation) +# Future Plan: +# No plan, we will maintain this patch util vllm change it behaviour # # * Worker Patch: # =============== diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index f88f2a9d83..0f6d9bc435 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -16,3 +16,4 @@ # import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa +import vllm_ascend.patch.platform.patch_common.patch_scheduler # noqa diff --git a/vllm_ascend/patch/platform/patch_common/patch_scheduler.py b/vllm_ascend/patch/platform/patch_common/patch_scheduler.py new file mode 100644 index 0000000000..b3f1e7fa77 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_scheduler.py @@ -0,0 +1,491 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +import time +from collections import deque + +from vllm.distributed.kv_events import KVEventBatch +from vllm.logger import init_logger +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine import EngineCoreEventType +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +def ascend_update_waiting_for_remote_kv(self, request) -> bool: + """ + P/D: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + if request.request_id not in self.finished_recving_kv_req_ids: + return False + assert len(self.kv_cache_config.kv_cache_groups + ) == 1, "KV connector only supports one KV cache group now" + # Now that the blocks are ready, actually cache them. + # In order to make decode node always do the decode step, we transfer every block as long as it contains the + # data computed by prefill node. + num_computed_tokens = request.num_tokens + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + self.kv_cache_manager.cache_blocks( + request, + num_computed_tokens, + ) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. 
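A small worked example of the token accounting in the patched `_update_waiting_for_remote_kv` above, assuming a 6-token prompt whose KV blocks have all arrived from the prefill node (numbers are illustrative):

```python
# Illustrative numbers only: a 6-token prompt whose KV blocks all arrived
# from the prefill node.
prompt_len = 6
num_computed_tokens = prompt_len      # every prompt token was transferred
if num_computed_tokens == prompt_len:
    num_computed_tokens -= 1          # leave exactly one token "uncomputed"

# The scheduler then schedules prompt_len - num_computed_tokens = 1 token,
# so the decode node runs a single forward step over the last prompt token
# and starts generating, instead of treating the request as fully precomputed.
assert prompt_len - num_computed_tokens == 1
```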
+ self.finished_recving_kv_req_ids.remove(request.request_id) + return True + + +def ascend_schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min(num_new_tokens, + self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - request.num_tokens, + 0) + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + # The request cannot be scheduled. 
+ # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event(EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[request.request_id] = req_index + req_to_new_block_ids[request.request_id] = (new_blocks.get_block_ids()) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later + skipped_waiting_requests: deque[Request] = deque() + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. 
+ if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras): + # Scheduling would exceed max_loras, skip. + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + if self.connector is not None and not self.cache_config.enable_prefix_caching \ + and num_new_tokens > token_budget: + break + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, + ) + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. 
+ skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + if request.use_structured_output: + structured_output_request_ids[request.request_id] = req_index + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id)) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + # Construct the scheduler output. 
+ new_reqs_data = [ + NewRequestData.from_request(req, req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. 
+ for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() + return scheduler_output + + +Scheduler._update_waiting_for_remote_kv = ascend_update_waiting_for_remote_kv +Scheduler.schedule = ascend_schedule diff --git a/vllm_ascend/soc_info.py b/vllm_ascend/soc_info.py new file mode 100644 index 0000000000..ac1317e8e1 --- /dev/null +++ b/vllm_ascend/soc_info.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +import torch_npu + + +@dataclass +class NPUSocInfo: + is_a3: bool = False + + def __post_init__(self): + torch_npu.npu._lazy_init() + self.soc_version = torch_npu._C._npu_get_soc_version() + if self.soc_version in (250, 251, 252, 253, 254, 255): + self.is_a3 = True diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 43b472ec9a..aa894ab6fd 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -17,7 +17,9 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # +import copy import gc +import math import os import time import types @@ -37,8 +39,11 @@ from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_dp_group, get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE @@ -762,7 +767,8 @@ def _process_reqs( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, - torch.Tensor, int, torch.Tensor]: + torch.Tensor, int, torch.Tensor, Optional[set[str]], + Optional[set[str]]]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -1023,6 +1029,7 @@ def _process_reqs( self.vllm_config, num_tokens=num_input_tokens): with ProfileExecuteDuration().capture_async("forward"): + self.maybe_setup_kv_connector(scheduler_output) model_kwargs = {} if self.torchair_graph_enabled: model_kwargs["kv_caches"] = self.kv_caches @@ -1047,6 +1054,9 @@ def _process_reqs( **model_kwargs, ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1071,7 +1081,8 @@ def _process_reqs( sample_indices = spec_decode_metadata.logits_indices return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, sample_indices) + total_num_scheduled_tokens, sample_indices, finished_sending, + finished_recving) def _calc_spec_decode_metadata( self, @@ -1242,12 +1253,17 @@ def execute_model( "prepare input and forward"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOuptut if there's no work to do. 
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            if not has_kv_transfer_group():
+                logger.debug(
+                    "skip this step since we receive the data from the remote disaggregated prefill node"
+                )
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            return self.kv_connector_no_forward(scheduler_output)

        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
+         num_scheduled_tokens, sample_indices, finished_sending,
+         finished_recving) = (self._process_reqs(scheduler_output,
+                                                 intermediate_tensors))

        with ProfileExecuteDuration().capture_async("post process"):
            logits = self.model.compute_logits(hidden_states[sample_indices],
@@ -1338,6 +1354,8 @@ def execute_model(
                hidden_states,
                attn_metadata,
            )
+            if has_kv_transfer_group():
+                get_kv_transfer_group().clear_connector_metadata()

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
@@ -1346,6 +1364,8 @@ def execute_model(
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
        )

        durations = ProfileExecuteDuration().pop_captured_sync()
@@ -1360,6 +1380,49 @@ def execute_model(

        return model_runner_output

+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfer(scheduler_output))
+            # For the case of no forward caused by receiving remote KV,
+            # one round of dummy inference is necessary
+            # to prevent hangs on the collective calls.
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update KVConnector with the KVConnector metadata before forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save() -> None:
+        if has_kv_transfer_group():
+            get_kv_transfer_group().wait_for_save()
+
+    @staticmethod
+    def get_finished_kv_transfer(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
    def _profile_multimodal(self) -> None:
        # TODO: handle encoder-decoder models once we support them.
        # NOTE: Currently model is profiled with a single non-text
@@ -1649,9 +1712,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        """
-        import torch_npu
        kv_caches: Dict[str, torch.Tensor] = {}

+        def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
+            data_ptr = tensor.data_ptr()
+            aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
+            offset = (aligned_addr - data_ptr) // tensor.element_size()
+            return tensor[int(offset):]
+
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
@@ -1684,6 +1752,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            # different GPUs, and `kv_cache_config.num_blocks` is set to
            # the min of all `num_blocks`. Verify it here.
            assert num_blocks >= kv_cache_config.num_blocks
+            alignment = 2 * 1024 * 1024
            # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
            # encounter OOM issue
            if isinstance(kv_cache_spec, FullAttentionSpec):
@@ -1691,29 +1760,51 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                dtype = kv_cache_spec.dtype
-                if self.torchair_graph_enabled:
-                    layer_kv_cache_nope = torch.zeros(
-                        kv_cache_shape[:-1] +
-                        (self.model_config.hf_text_config.kv_lora_rank, ),
-                        dtype=self.dtype,
-                        pin_memory=True,
-                        device=self.device)
-                    layer_kv_cache_pe = torch.zeros(
-                        kv_cache_shape[:-1] +
-                        (self.model_config.hf_text_config.qk_rope_head_dim,
-                         ),
-                        dtype=self.dtype,
-                        pin_memory=True,
-                        device=self.device)
-                    kv_caches[layer_name] = (layer_kv_cache_nope,
-                                             layer_kv_cache_pe)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                if self.model_config.is_deepseek_mla:
+                    # In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
+                    # address should be aligned to 2M. In most cases, torch_npu allocates 2M-aligned memory, but
+                    # we found some exceptions during testing, so we manually align the memory here. This part
+                    # of the code may consume an extra 2M * 2 * elem_size bytes of memory per layer.
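+                    # Illustration (hypothetical addresses, not from a real run): if a buffer's
+                    # data_ptr were 0x12345678 with alignment = 2M (0x200000), align_memory would
+                    # round it up to the next boundary, 0x12400000, and drop the leading
+                    # (0x12400000 - 0x12345678) // elem_size elements; the extra `alignment`
+                    # elements allocated below keep the aligned view large enough for the full
+                    # nope/rope cache shapes.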
+ num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + nope_cache_shape = (num_blocks, block_size, + num_kv_heads, nope_dim) + rope_cache_shape = (num_blocks, block_size, + num_kv_heads, rope_dim) + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + nope_cache = align_memory( + nope_cache, alignment)[:nope_allocate_shape].view( + nope_cache_shape) + rope_cache = align_memory( + rope_cache, alignment)[:rope_allocate_shape].view( + rope_cache_shape) + kv_caches[layer_name] = (nope_cache, rope_cache) else: - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - torch_npu.npu_format_cast(kv_caches[layer_name], 2) + num_caches = kv_cache_shape[0] + kv_cache_list = [] + for i in range(num_caches): + cache_shape = kv_cache_shape[1:] + cache_size = math.prod(cache_shape) + cache_size_aligned = cache_size + alignment + kv_cache = torch.zeros(cache_size_aligned, + dtype=dtype, + device=self.device) + kv_cache = align_memory( + kv_cache, + alignment)[:cache_size].view(cache_shape) + kv_cache_list.append(kv_cache) + kv_caches[layer_name] = kv_cache_list + # torch_npu.npu_format_cast(kv_caches[layer_name], 2) else: # TODO: add new branches when introducing more types of # KV cache specs. @@ -1724,6 +1815,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 6fe84a4580..5cd3679319 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -75,6 +75,9 @@ def __init__( is_driver_worker=is_driver_worker) # Try to import mindie_turbo to accelerate vLLM inference. + local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + world_size = self.vllm_config.parallel_config.world_size + self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank try_register_lib( "mindie_turbo", "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." 
From 66f72812d403714f3bee6c17eb55b49e5fe29689 Mon Sep 17 00:00:00 2001 From: underfituu Date: Thu, 19 Jun 2025 17:10:54 +0800 Subject: [PATCH 2/2] fix ut for v091 Signed-off-by: underfituu --- tests/ut/kv_connector/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 2ad3b17aeb..a8f65ff429 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -159,7 +159,6 @@ def create_request( multi_modal_inputs=None, multi_modal_placeholders=None, multi_modal_hashes=None, - pooling_params=None, eos_token_id=EOS_TOKEN_ID, ) req.kv_transfer_params = kv_transfer_params @@ -190,7 +189,6 @@ def create_model_runner_output( spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - pooler_output=[None], finished_sending=finished_sending, finished_recving=finished_recving, )