4 changes: 4 additions & 0 deletions .github/workflows/vllm_ascend_test_pd.yaml
@@ -103,3 +103,7 @@ jobs:
      - name: Run vllm-project/vllm-ascend PD Disaggregation test
        run: |
          pytest -sv tests/e2e/pd_disaggreate/test_pd_e2e.py

      - name: Run vllm-project/vllm-ascend PD Disaggregation edge test
        run: |
          bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
234 changes: 234 additions & 0 deletions examples/disaggregate_prefill_v1/README.md
@@ -0,0 +1,234 @@
# Disaggregated Prefill-Decode Deployment Guide

## Overview
This document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, using 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUs for two decode nodes (D1/D2).

## Prerequisites
- Ascend NPU environment with vLLM 0.9.1 installed
- Network interfaces configured for distributed communication (e.g. eth0)
- Model weights located at `/data01/deepseek_r1_w8a8_zhw`
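
A quick sanity check of these prerequisites might look like the following (a minimal sketch; the interface name `eth0` and the model path come from the list above and may differ in your environment):
```shell
# Confirm the NPUs are visible to the driver
npu-smi info -l

# Confirm the network interface used for distributed communication is up
ip link show eth0

# Confirm the model weights are present
ls /data01/deepseek_r1_w8a8_zhw
```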

## Rank table generation
The rank table is a JSON file that maps Ascend NPU devices to nodes and splits them between the prefill and decode roles. In this example, 16 devices are assigned to prefill and 16 devices to decode.

Run the following command on every node to generate the rank table:
```shell
cd vllm-ascend/examples/disaggregate_prefill_v1/
bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \
--npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16
```
The rank table will be generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json`
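
To confirm the device split before starting the services, you can inspect the generated file. The keys listed in the comment follow the structure produced by `gen_ranktable.py` (included later in this PR); the exact contents depend on your cluster:
```shell
cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1
# Expected top-level keys: version, server_count, prefill_device_list,
# decode_device_list, status
python3 -c "import json; d = json.load(open('ranktable.json')); print(list(d.keys()), len(d['prefill_device_list']), len(d['decode_device_list']))"
```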

## Start the disaggregated vLLM-ascend service
Execution sequence:
- The 4 configured node IPs are 172.19.32.175, 172.19.241.49, 172.19.123.51 and 172.19.190.36
- Start Prefill on Node 1 (P1)
- Start Prefill on Node 2 (P2)
- Start Decode on Node 1 (D1)
- Start Decode on Node 2 (D2)
- Start proxy server on Node1

* Run prefill server P1 on the first node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--api-server-count 2 \
--data-parallel-address 172.19.32.175 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}'
```

* Run prefill server P2 on the second node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--headless \
--data-parallel-size 2 \
--data-parallel-start-rank 1 \
--data-parallel-size-local 1 \
--data-parallel-address 172.19.32.175 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}'
```

* Run decode server D1 on the third node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--api-server-count 2 \
--data-parallel-address 172.19.123.51 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}'
```

* Run decode server D2 on the fourth node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
--headless \
--data-parallel-size 2 \
--data-parallel-start-rank 1 \
--data-parallel-size-local 1 \
--data-parallel-address 172.19.123.51 \
--data-parallel-rpc-port 13356 \
--tensor-parallel-size 8 \
--no-enable-prefix-caching \
--seed 1024 \
--served-model-name deepseek \
--max-model-len 6144 \
--max-num-batched-tokens 6144 \
--trust-remote-code \
--enforce-eager \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "LLMDataDistCMgrConnector",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_parallel_size": 1,
"kv_port": "20001",
"engine_id": "0",
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \
--additional-config \
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "expert_tensor_parallel_size": 1}'
```
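
Before starting the proxy, it can help to confirm that the two API-serving nodes (P1 and D1; the headless nodes P2 and D2 do not expose an API server) are reachable. A minimal sketch, assuming vLLM's standard `/health` endpoint:
```shell
curl -s -o /dev/null -w "P1 %{http_code}\n" http://172.19.32.175:20002/health
curl -s -o /dev/null -w "D1 %{http_code}\n" http://172.19.123.51:20002/health
```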

* Run proxy server on the first node
```shell
cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1
python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
```

* Verification
Check service health using the proxy server endpoint:
```shell
curl http://localhost:1025/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek",
"prompt": "Who are you?",
"max_tokens": 100,
"temperature": 0
}'
```

* Performance
Benchmark the deployment with the vLLM serving benchmark script, pointing it at the proxy endpoint (port 1025 in this example):
```shell
cd /vllm-workspace/vllm/benchmarks
python3 benchmark_serving.py \
--backend vllm \
--dataset-name random \
--random-input-len 4096 \
--random-output-len 1536 \
--num-prompts 256 \
--ignore-eos \
--model deepseek \
--tokenizer /data01/deepseek_r1_w8a8_zhw \
--host localhost \
--port 1025 \
--endpoint /v1/completions \
--max-concurrency 4 \
--request-rate 4
```
120 changes: 120 additions & 0 deletions examples/disaggregate_prefill_v1/gen_ranktable.py
@@ -0,0 +1,120 @@
import argparse
import json
import os

import torch.distributed as dist

from vllm_ascend.soc_info import NPUSocInfo

parser = argparse.ArgumentParser(
    description="Arguments of rank table generator", )
parser.add_argument("--local-host", type=str, required=True, help="local ip")
parser.add_argument("--prefill-device-cnt",
                    type=int,
                    required=True,
                    help="number of prefill devices")
parser.add_argument("--decode-device-cnt",
                    type=int,
                    required=True,
                    help="number of decode devices")
args = parser.parse_args()
local_host = args.local_host
prefill_device_cnt = args.prefill_device_cnt
decode_device_cnt = args.decode_device_cnt

print("enter py")

hccn_tool_path = os.environ.get("HCCN_TOOL_PATH",
                                "/usr/local/Ascend/driver/tools/hccn_tool")
master_addr = os.environ.get("MASTER_ADDR")
master_port = os.environ.get("MASTER_PORT")
rank = os.environ.get("RANK")
local_rank = os.environ.get("LOCAL_RANK")
# This variable is set by torchrun,
# and is different from WORLD_SIZE in gen_rank_table.sh.
world_size = os.environ.get("WORLD_SIZE")
soc_info = NPUSocInfo()


def get_cmd_stdout(cmd):
    import subprocess
    return subprocess.run(cmd, capture_output=True,
                          shell=True).stdout.decode("utf-8").strip()


print(f"local_host: {local_host}")
print("gen ranktable.json")

num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split(
    ":")[1].strip()
num_cards = int(num_cards)
chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split(
    "\n")[0].split(":")[1].strip()
chips_per_card = int(chips_per_card)

# generate local device list for local rank 0, and gather it to all ranks
local_device_list: list[dict[str, str]] = list()
if local_rank == "0":
    super_pod_id = "0"
    for card_id in range(num_cards):
        for chip_id in range(chips_per_card):
            device_id = card_id * chips_per_card + chip_id
            if soc_info.is_a3:
                device_ip = get_cmd_stdout(
                    f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
                ).split(":")[1].strip()
                super_device_id = get_cmd_stdout(
                    f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID"
                ).split(":")[1].strip()
                super_pod_id = get_cmd_stdout(
                    f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\""
                ).split(":")[1].strip()
            else:
                device_ip = get_cmd_stdout(
                    f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr"
                ).split(":")[1].strip()

            device_info = {
                "server_id": local_host,
                "device_id": str(device_id),
                "device_ip": str(device_ip),
            }
            if soc_info.is_a3:
                device_info.update({
                    "super_pod_id": str(super_pod_id),
                    "super_device_id": str(super_device_id)
                })
            local_device_list.append(device_info)

dist.init_process_group(backend=dist.Backend.GLOO)
global_device_list = [None] * dist.get_world_size()
dist.all_gather_object(global_device_list, local_device_list)
global_device_list = [
    device_info for device_list in global_device_list
    for device_info in device_list  # type: ignore[attr-defined]
]
cnt = 1
for device_info in global_device_list:  # type: ignore[assignment]
    device_info["cluster_id"] = str(cnt)
    cnt += 1
assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \
    "prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster"
ranktable = {
    "version": "1.2",
    "server_count": str(world_size),
    "prefill_device_list": global_device_list[:prefill_device_cnt],
    "decode_device_list": global_device_list[prefill_device_cnt:prefill_device_cnt + decode_device_cnt],
    "status": "completed"
}

if local_rank == '0':
    with open("ranktable.json", "w") as f:
        json.dump(ranktable, f, indent=4)

print("gen ranktable.json done")