11import msgspec
2+ import os
23from dataclasses import dataclass
34
45from typing import Optional , Any , Tuple
2829# from .llmdatadist_connector_v1 import TORCH_DTYPE_TO_NPU_DTYPE
2930from vllm .v1 .request import Request
3031from vllm .utils import logger
32+ from vllm_ascend .soc_info import NPUSocInfo
3133
3234import llm_datadist
3335from llm_datadist import LLMDataDist , LLMRole , CacheDesc , BlocksCacheKey , LLMConfig , LLMException
@@ -314,14 +316,15 @@ def __init__(
314316 self .llm_datadist = LLMDataDist (self .llm_datadist_role , self .local_agent_metadata .cluster_id )
315317 self .init_llm_datadist ()
316318 self .finished_reqs = set ()
319+ self .soc_info = NPUSocInfo ()
317320 # remote_ip, remote_rank = self.get_remote_ip_and_rank()
318321 # for idx in range(len(remote_ip)):
319322 # remote_agent_meta = self.read_agent_metadata(global_rank_table, remote_ip[idx], remote_rank[idx], self.llm_datadist_remote_role)
320323 # self.add_remote_agent(remote_agent_meta)
321324
322325
323326 def listen_for_agent_metadat_req (self , event : threading .Event ):
324- port = envs .VLLM_LLMDD_CHANNEL_PORT + self .local_dp_rank * self .tp_size
327+ port = envs .VLLM_LLMDD_CHANNEL_PORT + self .local_dp_rank * self .tp_size + self . tp_rank
325328 url = f"tcp://0.0.0.0:{ port } "
326329 msg_encoder = msgspec .msgpack .Encoder ()
327330 msg_decoder = msgspec .msgpack .Decoder ()
@@ -394,11 +397,11 @@ def read_agent_metadata(self, global_rank_table, server_id, device_id, agent_rol
394397 continue
395398 if device_info ["device_id" ] != str (device_id ):
396399 continue
397- super_pod_id_ = device_info [ "super_pod_id" ]
400+ super_pod_id_ = device_info . get ( "super_pod_id" , None )
398401 server_id_ = device_info ["server_id" ]
399402 device_id_ = device_info ["device_id" ]
400403 device_ip_ = device_info ["device_ip" ]
401- super_device_id_ = device_info [ "super_device_id" ]
404+ super_device_id_ = device_info . get ( "super_device_id" , None )
402405 cluster_id_ = int (device_info ["cluster_id" ])
403406 agent_metadata = LLMDataDistAgentMetadata (
404407 super_pod_id = super_pod_id_ ,
@@ -537,32 +540,38 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
537540 decode_server_device_info = None
538541 prefill_server_device_info = {
539542 "device" : [
540- {
541- "device_id" : prefill_metadata .device_id ,
542- "device_ip" : prefill_metadata .device_ip ,
543- "super_device_id" : prefill_metadata .super_device_id ,
544- "rank_id" : "0"
543+ {
544+ k : v for k , v in [
545+ ("device_id" , prefill_metadata .device_id ),
546+ ("device_ip" , prefill_metadata .device_ip ),
547+ ("super_device_id" , prefill_metadata .super_device_id ),
548+ ("rank_id" , "0" )]
549+ if v is not None
545550 }
546551 ],
547552 "server_id" : prefill_metadata .server_id
548553 }
549554 if is_same_server :
550555 prefill_server_device_info ["device" ].append (
551- {
552- "device_id" : decode_metadata .device_id ,
553- "device_ip" : decode_metadata .device_ip ,
554- "super_device_id" : decode_metadata .super_device_id ,
555- "rank_id" : "1"
556+ {
557+ k : v for k , v in [
558+ ("device_id" , decode_metadata .device_id ),
559+ ("device_ip" , decode_metadata .device_ip ),
560+ ("super_device_id" , decode_metadata .super_device_id ),
561+ ("rank_id" , "1" )]
562+ if v is not None
556563 }
557564 )
558565 else :
559566 decode_server_device_info = {
560567 "device" : [
561- {
562- "device_id" : decode_metadata .device_id ,
563- "device_ip" : decode_metadata .device_ip ,
564- "super_device_id" : decode_metadata .super_device_id ,
565- "rank_id" : "1"
568+ {
569+ k : v for k , v in [
570+ ("device_id" , decode_metadata .device_id ),
571+ ("device_ip" , decode_metadata .device_ip ),
572+ ("super_device_id" , decode_metadata .super_device_id ),
573+ ("rank_id" , "1" )]
574+ if v is not None
566575 }
567576 ],
568577 "server_id" : decode_metadata .server_id
@@ -571,28 +580,29 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
571580 if decode_server_device_info is not None :
572581 rank_table ["server_list" ].append (decode_server_device_info )
573582
574- # generate super_pod_list for rank table
575- super_pod_list = []
576- prefill_super_pod_info = {
577- "super_pod_id" : prefill_metadata .super_pod_id ,
578- "server_list" : [
579- {"server_id" : prefill_metadata .server_id }
580- ],
581- }
582- if is_same_pod and not is_same_server :
583- prefill_super_pod_info ["server_list" ].append (
584- {"server_id" : decode_metadata .server_id }
585- )
586- super_pod_list .append (prefill_super_pod_info )
587- if not is_same_pod :
588- decode_super_pod_id = {
589- "super_pod_id" : decode_metadata .super_pod_id ,
583+ if self .soc_info .is_a3 :
584+ # generate super_pod_list for rank table
585+ super_pod_list = []
586+ prefill_super_pod_info = {
587+ "super_pod_id" : prefill_metadata .super_pod_id ,
590588 "server_list" : [
591- {"server_id" : decode_metadata .server_id }
592- ],
589+ {"server_id" : prefill_metadata .server_id }
590+ ],
593591 }
594- super_pod_list .append (decode_super_pod_id )
595- rank_table ["super_pod_list" ] = super_pod_list
592+ if is_same_pod and not is_same_server :
593+ prefill_super_pod_info ["server_list" ].append (
594+ {"server_id" : decode_metadata .server_id }
595+ )
596+ super_pod_list .append (prefill_super_pod_info )
597+ if not is_same_pod :
598+ decode_super_pod_id = {
599+ "super_pod_id" : decode_metadata .super_pod_id ,
600+ "server_list" : [
601+ {"server_id" : decode_metadata .server_id }
602+ ],
603+ }
604+ super_pod_list .append (decode_super_pod_id )
605+ rank_table ["super_pod_list" ] = super_pod_list
596606 logger .info (f"LLMDataDistConnectorWorker: try link with remote, comm id: { comm_name } " )
597607 logger .info (f"rank table \n { rank_table } " )
598608 logger .info (f"comm name: { comm_name } " )
0 commit comments