@@ -167,6 +167,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: int):
167167 # Cannot retrieve the parallel config since it is not initialized.
168168 self .local_dp_rank = None
169169 self .tp_size = None
170+ self .port = self .vllm_config .parallel_config .data_parallel_rank_local * self .vllm_config .parallel_config .tensor_parallel_size + envs .VLLM_LLMDD_CHANNEL_PORT
170171
171172 self ._reqs_need_recv : dict [str , tuple [Request , list [int ]]] = {}
172173
@@ -244,12 +245,6 @@ def request_finished(
244245 request : "Request" ,
245246 block_ids : list [int ],
246247 ) -> tuple [bool , Optional [dict [str , Any ]]]:
247- if self .local_dp_rank is None :
248- vllm_config = get_current_vllm_config ()
249- # Need this dp rank to locate the only dp rank the kv cache from
250- self .local_dp_rank = vllm_config .parallel_config .data_parallel_rank_local
251- # Need this tp size to offset the port in tp size
252- self .tp_size = vllm_config .parallel_config .tensor_parallel_size
253248
254249 params = request .kv_transfer_params
255250 logger .debug (
@@ -267,13 +262,13 @@ def request_finished(
267262 # If prompt < block_size, no xfer so free blocks immediately.
268263 delay_free_blocks = len (computed_block_ids ) > 0
269264
270- return delay_free_blocks , dict (
265+ return False , dict (
271266 do_remote_prefill = True ,
272267 do_remote_decode = False ,
273268 remote_block_ids = computed_block_ids ,
274269 remote_engine_id = self .engine_id ,
275270 remote_host = self .local_ip ,
276- remote_port = envs . VLLM_LLMDD_CHANNEL_PORT + self .local_dp_rank * self . tp_size ,
271+ remote_port = self .port ,
277272 )
278273
279274class LLMDataDistConnectorWorker ():
@@ -295,6 +290,7 @@ def __init__(
295290 self .local_ip = get_ip ()
296291 self .kv_transfer_config : Optional [KVTransferConfig ] = vllm_config .kv_transfer_config
297292 self .local_agent_metadata : Optional [LLMDataDistAgentMetadata ] = None
293+ self .vllm_config = vllm_config
298294
299295 self .llm_datadist_role = None
300296 self .llm_datadist_remote_role = None
@@ -498,11 +494,11 @@ def start_load_kv(self, metadata: LLMDataDistConnectorMetadata):
498494 )
499495 self .finished_reqs .add (req_id )
500496
501- def add_remote_agent (self , metadata : LLMDataDistAgentMetadata ) -> bool :
497+ def add_remote_agent (self , metadata : LLMDataDistAgentMetadata ) -> int :
502498 remote_cluster_id = metadata .cluster_id
503499 if remote_cluster_id in self .linked_cluster :
504500 logger .debug (f"LLMDataDistConnectorWorker: remote cluster_id: { metadata .cluster_id } already linked with this server, skip the connection" )
505- return False
501+ return remote_cluster_id
506502 remote_super_pod_id = metadata .super_pod_id
507503 remote_device_id = metadata .device_id
508504 remote_device_ip = metadata .device_ip
@@ -618,10 +614,10 @@ def add_remote_agent(self, metadata: LLMDataDistAgentMetadata) -> bool:
618614 raise RuntimeError (f"LLMDataDistConnectorWorker: Linking failed, comm id: { comm_id } " )
619615 time .sleep (1 )
620616 logger .info ("Checking query_register_mem_status again" )
621- self .linked_cluster .update ({remote_server_id : ( remote_cluster_id , comm_id ) })
617+ self .linked_cluster .update ({remote_cluster_id : comm_id })
622618 logger .info (f"cached linked cluster: { self .linked_cluster } " )
623619 logger .info (f"Sucessfully build link with cluster id { remote_cluster_id } with cluster name { comm_name } !" )
624- return True
620+ return remote_cluster_id
625621
626622
627623 def remove_remote_agent (self , cluster_id : int ):
@@ -641,7 +637,7 @@ def connect_to_remote_agent(
641637 host : str ,
642638 port : int
643639 ):
644- url = f"tcp://{ host } :{ port + self . tp_rank } "
640+ url = f"tcp://{ host } :{ port } "
645641 logger .debug (f"Querying metadata from url: { url } " )
646642 msg_encoder = msgspec .msgpack .Encoder ()
647643 msg_send = msg_encoder .encode (self .local_agent_metadata )
@@ -653,7 +649,8 @@ def connect_to_remote_agent(
653649 metadata = decoder .decode (metadata_bytes )
654650 metadata = LLMDataDistAgentMetadata (** metadata )
655651 logger .info (f"recving metadata: { metadata } " )
656- self .add_remote_agent (metadata )
652+ cluster_id = self .add_remote_agent (metadata )
653+ return cluster_id
657654
658655 def _read_blocks (
659656 self ,
@@ -664,8 +661,8 @@ def _read_blocks(
664661 remote_engine_id : str ,
665662 request_id : str ,
666663 ):
667- if remote_ip not in self .linked_cluster :
668- self .connect_to_remote_agent (remote_ip , remote_port )
664+ # if remote_ip not in self.linked_cluster:
665+ self .connect_to_remote_agent (remote_ip , remote_port + self . tp_rank )
669666 num_local_blocks = len (local_block_ids )
670667 if num_local_blocks == 0 :
671668 return
@@ -681,8 +678,8 @@ def _read_blocks(
681678 remote_cache_key_k_pe = BlocksCacheKey (cluster_id = remote_cluster_id , model_id = 1 )
682679 logger .info ("Try pull blocks from remote server" )
683680 try :
684- self .cache_manager .pull_blocks (remote_cache_key_k_normed , self .cache [0 ], local_block_ids , remote_block_ids )
685- self .cache_manager .pull_blocks (remote_cache_key_k_pe , self .cache [1 ], local_block_ids , remote_block_ids )
681+ self .cache_manager .pull_blocks (remote_cache_key_k_normed , self .cache [0 ], remote_block_ids , local_block_ids )
682+ self .cache_manager .pull_blocks (remote_cache_key_k_pe , self .cache [1 ], remote_block_ids , local_block_ids )
686683 except (TypeError , ValueError ) as e :
687684 raise RuntimeError (f"LLMDataDistConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: { remote_cache_key } , cache: { self .cache } , local_block_ids: { local_block_ids } , remote_block_ids: { remote_block_ids } " )
688685 except LLMException :
@@ -691,7 +688,7 @@ def _read_blocks(
691688 remote_cache_key = BlocksCacheKey (cluster_id = remote_cluster_id )
692689 logger .info ("Try pull blocks from remote server" )
693690 try :
694- self .cache_manager .pull_blocks (remote_cache_key , self .cache , local_block_ids , remote_block_ids )
691+ self .cache_manager .pull_blocks (remote_cache_key , self .cache , remote_block_ids , local_block_ids )
695692 except (TypeError , ValueError ) as e :
696693 raise RuntimeError (f"LLMDataDistConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: { remote_cache_key } , cache: { self .cache } , local_block_ids: { local_block_ids } , remote_block_ids: { remote_block_ids } " )
697694 except LLMException :
0 commit comments