Skip to content

Commit cc4f9fb

Browse files
liziyu179ganyi1996ppo
authored andcommitted
support a2
1 parent e7b020a commit cc4f9fb

File tree

4 files changed

+196
-38
lines changed

4 files changed

+196
-38
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import os
2+
import torch.distributed as dist
3+
import json
4+
import argparse
5+
from vllm_ascend.soc_info import NPUSocInfo
6+
7+
parser = argparse.ArgumentParser(
8+
description="Arguments of rank table generator",
9+
)
10+
parser.add_argument(
11+
"--prefill-device-cnt", type=int, required=True, help="number of prefill devices"
12+
)
13+
parser.add_argument(
14+
"--decode-device-cnt", type=int, required=True, help="number of decode devices"
15+
)
16+
args = parser.parse_args()
17+
prefill_device_cnt = args.prefill_device_cnt
18+
decode_device_cnt = args.decode_device_cnt
19+
20+
print("enter py")
21+
22+
master_addr = os.environ.get("MASTER_ADDR")
23+
master_port = os.environ.get("MASTER_PORT")
24+
rank = os.environ.get("RANK")
25+
# This variable is set by torchrun,
26+
# and is different from WORLD_SIZE in gen_rank_table.sh.
27+
world_size = os.environ.get("WORLD_SIZE")
28+
soc_info = NPUSocInfo()
29+
30+
def get_cmd_stdout(cmd):
31+
import subprocess
32+
return subprocess.run(
33+
cmd,
34+
capture_output=True,
35+
shell=True
36+
).stdout.decode("utf-8").strip()
37+
38+
local_host = get_cmd_stdout("hostname -I | awk -F \" \" \'{print$1}\'")
39+
print(f"local_host: {local_host}")
40+
print("gen ranktable.json")
41+
42+
num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split(":")[1].strip()
43+
num_cards = int(num_cards)
44+
chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split("\n")[0].split(":")[1].strip()
45+
chips_per_card = int(chips_per_card)
46+
47+
local_device_list: list[dict[str, str]] = list()
48+
super_pod_id = "0"
49+
for card_id in range(num_cards):
50+
for chip_id in range(chips_per_card):
51+
device_id = card_id * chips_per_card + chip_id
52+
if soc_info.is_a3:
53+
device_ip = get_cmd_stdout(f"/usr/local/Ascend/driver/tools/hccn_tool -i {device_id} -vnic -g | grep ipaddr").split(":")[1].strip()
54+
super_device_id = get_cmd_stdout(f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID").split(":")[1].strip()
55+
super_pod_id = get_cmd_stdout(f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"").split(":")[1].strip()
56+
else:
57+
device_ip = get_cmd_stdout(f"/usr/local/Ascend/driver/tools/hccn_tool -i {device_id} -ip -g | grep ipaddr").split(":")[1].strip()
58+
59+
device_info = {
60+
"server_id": local_host,
61+
"device_id": str(device_id),
62+
"device_ip": str(device_ip),
63+
}
64+
if soc_info.is_a3:
65+
device_info.update({"super_pod_id": str(super_pod_id), "super_device_id": str(super_device_id)})
66+
local_device_list.append(device_info)
67+
68+
dist.init_process_group(backend=dist.Backend.GLOO)
69+
global_device_list = [None] * dist.get_world_size()
70+
dist.all_gather_object(global_device_list, local_device_list)
71+
global_device_list = [device_info for device_list in global_device_list for device_info in device_list]
72+
cnt = 1
73+
for device_info in global_device_list:
74+
device_info["cluster_id"] = str(cnt)
75+
cnt += 1
76+
assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \
77+
"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster"
78+
ranktable = {
79+
"version": "1.2",
80+
"server_count": str(world_size),
81+
"prefill_device_list": global_device_list[:prefill_device_cnt],
82+
"decode_device_list": global_device_list[prefill_device_cnt:prefill_device_cnt+decode_device_cnt],
83+
"status": "completed"
84+
}
85+
86+
87+
with open("ranktable.json", "w") as f:
88+
json.dump(ranktable, f, indent=4)
89+
90+
print("gen ranktable.json done")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
gen_rank_table.sh
2+
#!/bin/bash
3+
4+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
5+
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}
6+
7+
IPs=('1.0.0.0' '1.0.0.1')
8+
LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'`
9+
GPUS_PER_NODE=8
10+
MASTER_ADDR=${IPs[0]}
11+
MASTER_PORT=6657
12+
NNODES=${#IPs[@]}
13+
NODE_RANK="2"
14+
for i in "${!IPs[@]}";
15+
do
16+
echo "${IPs[$i]}"
17+
if [ "$LOCAL_HOST" == "${IPs[$i]}" ];
18+
then
19+
NODE_RANK=$i
20+
break
21+
fi
22+
done
23+
if [[ $NODE_RANK == "" ]];then
24+
echo "[Error] para \"NODE_RANK\" must be confing"
25+
exit 1
26+
fi
27+
28+
WORLD_SIZE=$(($GPUS_PER_NODE * $NNODES))
29+
RANKSTART=`expr $GPUS_PER_NODE \* $NODE_RANK`
30+
31+
echo "========>param:"
32+
echo "WORLD_SIZE: " $WORLD_SIZE
33+
echo "RANKSTART": $RANKSTART
34+
echo "NNODES": $NNODES
35+
echo "NODE_RANK": $NODE_RANK
36+
echo "==============="
37+
38+
if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then
39+
GLOO_SOCKET_IFNAME=enp189s0f0 torchrun \
40+
--nproc_per_node 1 \
41+
--nnodes ${NNODES} \
42+
--node_rank ${NODE_RANK} \
43+
--master_addr ${MASTER_ADDR} \
44+
--master_port ${MASTER_PORT} \
45+
gen_ranktable.py --prefill-device-cnt $1 --decode-device-cnt $2
46+
fi

vllm_ascend/distributed/llmdatadist_connector_v1_a3.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import msgspec
2+
import os
23
from dataclasses import dataclass
34

45
from typing import Optional, Any, Tuple
@@ -28,6 +29,7 @@
2829
# from .llmdatadist_connector_v1 import TORCH_DTYPE_TO_NPU_DTYPE
2930
from vllm.v1.request import Request
3031
from vllm.utils import logger
32+
from vllm_ascend.soc_info import NPUSocInfo
3133

3234
import llm_datadist
3335
from llm_datadist import LLMDataDist, LLMRole, CacheDesc, BlocksCacheKey, LLMConfig, LLMException
@@ -314,14 +316,15 @@ def __init__(
314316
self.llm_datadist = LLMDataDist(self.llm_datadist_role, self.local_agent_metadata.cluster_id)
315317
self.init_llm_datadist()
316318
self.finished_reqs = set()
319+
self.soc_info = NPUSocInfo()
317320
# remote_ip, remote_rank = self.get_remote_ip_and_rank()
318321
# for idx in range(len(remote_ip)):
319322
# remote_agent_meta = self.read_agent_metadata(global_rank_table, remote_ip[idx], remote_rank[idx], self.llm_datadist_remote_role)
320323
# self.add_remote_agent(remote_agent_meta)
321324

322325

323326
def listen_for_agent_metadat_req(self, event: threading.Event):
324-
port = envs.VLLM_LLMDD_CHANNEL_PORT + self.local_dp_rank * self.tp_size
327+
port = envs.VLLM_LLMDD_CHANNEL_PORT + self.local_dp_rank * self.tp_size + self.tp_rank
325328
url = f"tcp://0.0.0.0:{port}"
326329
msg_encoder = msgspec.msgpack.Encoder()
327330
msg_decoder = msgspec.msgpack.Decoder()
@@ -394,11 +397,11 @@ def read_agent_metadata(self, global_rank_table, server_id, device_id, agent_rol
394397
continue
395398
if device_info["device_id"] != str(device_id):
396399
continue
397-
super_pod_id_ = device_info["super_pod_id"]
400+
super_pod_id_ = device_info.get("super_pod_id", None)
398401
server_id_ = device_info["server_id"]
399402
device_id_ = device_info["device_id"]
400403
device_ip_ = device_info["device_ip"]
401-
super_device_id_ = device_info["super_device_id"]
404+
super_device_id_ = device_info.get("super_device_id", None)
402405
cluster_id_ = int(device_info["cluster_id"])
403406
agent_metadata = LLMDataDistAgentMetadata(
404407
super_pod_id=super_pod_id_,
@@ -537,32 +540,38 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
537540
decode_server_device_info = None
538541
prefill_server_device_info = {
539542
"device": [
540-
{
541-
"device_id": prefill_metadata.device_id,
542-
"device_ip": prefill_metadata.device_ip,
543-
"super_device_id": prefill_metadata.super_device_id,
544-
"rank_id": "0"
543+
{
544+
k: v for k, v in [
545+
("device_id", prefill_metadata.device_id),
546+
("device_ip", prefill_metadata.device_ip),
547+
("super_device_id", prefill_metadata.super_device_id),
548+
("rank_id", "0")]
549+
if v is not None
545550
}
546551
],
547552
"server_id": prefill_metadata.server_id
548553
}
549554
if is_same_server:
550555
prefill_server_device_info["device"].append(
551-
{
552-
"device_id": decode_metadata.device_id,
553-
"device_ip": decode_metadata.device_ip,
554-
"super_device_id": decode_metadata.super_device_id,
555-
"rank_id": "1"
556+
{
557+
k: v for k, v in [
558+
("device_id", decode_metadata.device_id),
559+
("device_ip", decode_metadata.device_ip),
560+
("super_device_id", decode_metadata.super_device_id),
561+
("rank_id", "1")]
562+
if v is not None
556563
}
557564
)
558565
else:
559566
decode_server_device_info = {
560567
"device": [
561-
{
562-
"device_id": decode_metadata.device_id,
563-
"device_ip": decode_metadata.device_ip,
564-
"super_device_id": decode_metadata.super_device_id,
565-
"rank_id": "1"
568+
{
569+
k: v for k, v in [
570+
("device_id", decode_metadata.device_id),
571+
("device_ip", decode_metadata.device_ip),
572+
("super_device_id", decode_metadata.super_device_id),
573+
("rank_id", "1")]
574+
if v is not None
566575
}
567576
],
568577
"server_id": decode_metadata.server_id
@@ -571,28 +580,29 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
571580
if decode_server_device_info is not None:
572581
rank_table["server_list"].append(decode_server_device_info)
573582

574-
# generate super_pod_list for rank table
575-
super_pod_list = []
576-
prefill_super_pod_info = {
577-
"super_pod_id": prefill_metadata.super_pod_id,
578-
"server_list": [
579-
{"server_id": prefill_metadata.server_id}
580-
],
581-
}
582-
if is_same_pod and not is_same_server:
583-
prefill_super_pod_info["server_list"].append(
584-
{"server_id": decode_metadata.server_id}
585-
)
586-
super_pod_list.append(prefill_super_pod_info)
587-
if not is_same_pod:
588-
decode_super_pod_id = {
589-
"super_pod_id": decode_metadata.super_pod_id,
583+
if self.soc_info.is_a3:
584+
# generate super_pod_list for rank table
585+
super_pod_list = []
586+
prefill_super_pod_info = {
587+
"super_pod_id": prefill_metadata.super_pod_id,
590588
"server_list": [
591-
{"server_id": decode_metadata.server_id}
592-
],
589+
{"server_id": prefill_metadata.server_id}
590+
],
593591
}
594-
super_pod_list.append(decode_super_pod_id)
595-
rank_table["super_pod_list"] = super_pod_list
592+
if is_same_pod and not is_same_server:
593+
prefill_super_pod_info["server_list"].append(
594+
{"server_id": decode_metadata.server_id}
595+
)
596+
super_pod_list.append(prefill_super_pod_info)
597+
if not is_same_pod:
598+
decode_super_pod_id = {
599+
"super_pod_id": decode_metadata.super_pod_id,
600+
"server_list": [
601+
{"server_id": decode_metadata.server_id}
602+
],
603+
}
604+
super_pod_list.append(decode_super_pod_id)
605+
rank_table["super_pod_list"] = super_pod_list
596606
logger.info(f"LLMDataDistConnectorWorker: try link with remote, comm id: {comm_name}")
597607
logger.info(f"rank table \n{rank_table}")
598608
logger.info(f"comm name: {comm_name}")

vllm_ascend/soc_info.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from dataclasses import dataclass
2+
import torch_npu
3+
4+
@dataclass
5+
class NPUSocInfo:
6+
is_a3: bool = False
7+
8+
def __post_init__(self):
9+
torch_npu.npu._lazy_init()
10+
self.soc_version = torch_npu._C._npu_get_soc_version()
11+
if self.soc_version in (253, 254, 255):
12+
self.is_a3 = True

0 commit comments

Comments
 (0)