     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group, get_world_group)
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -172,6 +172,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

         # Requests that need to start recv.
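
The new base-port arithmetic is easiest to see with concrete numbers. A minimal sketch, assuming a side-channel port of 5600 and tensor-parallel size 4 (both values hypothetical):

# Stand-ins for envs.VLLM_NIXL_SIDE_CHANNEL_PORT and
# vllm_config.parallel_config.tensor_parallel_size.
SIDE_CHANNEL_PORT = 5600
TP_SIZE = 4

for dp_rank_local in range(2):
    base_port = SIDE_CHANNEL_PORT + dp_rank_local * TP_SIZE
    print(f"DP rank {dp_rank_local} -> base port {base_port}")
# DP rank 0 -> 5600, DP rank 1 -> 5604: each DP rank owns a disjoint
# range of TP_SIZE consecutive ports.
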
@@ -310,8 +315,8 @@ def request_finished(
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
-            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
-            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
         )

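With the change above, the params returned for a finished prefill now carry the DP-rank-specific side channel. Roughly, the decode side receives something like the following (all values hypothetical; the field names mirror the kwargs above):

kv_transfer_params = {
    "do_remote_decode": False,
    "remote_block_ids": [0, 1, 2, 3],
    "remote_engine_id": "prefill-engine-0",
    "remote_host": "10.0.0.5",  # self.side_channel_host
    "remote_port": 5604,        # base port for this DP rank
}
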
@@ -330,11 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Map of engine_id -> agent_name.
         self._remote_agents: dict[str, str] = {}

+        # NIXL handshake port.
+        # NOTE(rob): Within a DP group, each DP rank gets its own
+        # base port (which is sent in the KVTransferParams).
+        # Each TP rank listens/queries on the base_port + tp_rank.
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
+
         # Metadata.
         self.engine_id = engine_id
-        self.rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
-        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
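
Each TP rank then derives its own listener address from that base port. A small sketch of the resulting paths (host and port are hypothetical; zmq_path stands in for vllm.utils.make_zmq_path):

def zmq_path(host: str, port: int) -> str:
    # Stand-in for make_zmq_path("tcp", host, port).
    return f"tcp://{host}:{port}"

base_port = 5604  # hypothetical: DP rank 1 with TP size 4
for tp_rank in range(4):
    print(zmq_path("10.0.0.5", base_port + tp_rank))
# tcp://10.0.0.5:5604 ... tcp://10.0.0.5:5607 -- one side channel per
# TP rank, with no collisions across DP ranks.
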
@@ -383,16 +396,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event,
-                                 world_rank: int):
+                                 ready_event: threading.Event, base_port: int,
+                                 tp_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach like an ETCD server in the future.
-
-        # NOTE(rob): to support heterogeneous TP, we will have to
-        # move this into the scheduler rather than worker, since
-        # each rank needs the metadata of all other ranks (whereas
-        # in this setup, each rank only gets one other rank's meta.
+        # to a better approach via HTTP endpoint soon.

         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(metadata)
@@ -402,11 +410,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        # NOTE(rob): we need each rank to have a unique port. This
-        # hack to keeps us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
-        path = make_zmq_path("tcp", host, port)
+        path = make_zmq_path("tcp", host, base_port + tp_rank)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
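
For reference, the ROUTER-side loop boils down to replying to every query with the pre-encoded agent metadata. A self-contained sketch using raw pyzmq (the real code goes through vllm's zmq_ctx helper; serve_metadata and the frame layout here are illustrative):

import msgspec
import zmq

def serve_metadata(path: str, metadata: object) -> None:
    payload = msgspec.msgpack.Encoder().encode(metadata)
    with zmq.Context() as ctx:
        sock = ctx.socket(zmq.ROUTER)
        sock.bind(path)
        while True:
            # A REQ peer arrives as [identity, empty delimiter, body].
            identity, _, _request = sock.recv_multipart()
            sock.send_multipart((identity, b"", payload))
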
@@ -421,10 +425,10 @@ def _nixl_handshake(self, host: str, port: int):
421425 """Do a NIXL handshake with a remote instance."""
422426
423427 start_time = time .perf_counter ()
424- # NOTE(rob): we need each rank to have a unique port. This is
425- # a hack to keep us moving. We will switch when moving to etcd
426- # or where we have a single ZMQ socket in the scheduler .
427- path = make_zmq_path ("tcp" , host , port + self .world_rank )
428+ # NOTE(rob): we need each tp_rank to have a unique port.
429+ # This is a hack to keep us moving. We will switch when
430+ # we switch to HTTP-based NIXL metadata exchange .
431+ path = make_zmq_path ("tcp" , host , port + self .tp_rank )
428432 logger .debug ("Querying metadata on path: %s" , path )
429433 with zmq_ctx (zmq .REQ , path ) as sock :
430434 # Send query for the request.
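
The REQ side of the same exchange, again as a hedged standalone sketch (fetch_metadata is hypothetical; the real code decodes the reply into NixlAgentMetadata and registers the remote agent):

import msgspec
import zmq

def fetch_metadata(host: str, base_port: int, tp_rank: int):
    with zmq.Context() as ctx:
        sock = ctx.socket(zmq.REQ)
        sock.connect(f"tcp://{host}:{base_port + tp_rank}")
        sock.send(b"get_metadata")  # body is opaque to the listener
        return msgspec.msgpack.Decoder().decode(sock.recv())
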
@@ -532,7 +536,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.world_rank),
+            args=(metadata, ready_event, self.side_channel_port,
+                  self.tp_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
@@ -556,9 +560,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
             block_offset = block_id * self.block_len
             # (addr, len, device id)
             blocks_data.append(
-                (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
+                (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -573,9 +577,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
             block_offset = block_id * self.block_len
             # (addr, len, device id)
             blocks_data.append(
-                (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
+                (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
+                     len(blocks_data), engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
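
The descriptor lists built in both branches have the same shape: one (address, length, device id) tuple per KV block, offset from a region's base address. A tiny sketch with hypothetical values:

base_addr = 0x7F0000000000   # hypothetical base of one KV cache region
block_len = 2 * 1024 * 1024  # hypothetical bytes per block
tp_rank = 0
blocks_data = [
    (base_addr + block_id * block_len, block_len, tp_rank)
    for block_id in range(4)
]
# Each tuple describes a contiguous VRAM span NIXL can transfer on its own.
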
@@ -600,14 +604,14 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
-                "and %s requests done recving", self.rank, len(done_sending),
-                len(done_recving))
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))

         if self.world_size == 1:
             return done_sending, done_recving

         # Rank 0: get finished from all other ranks.
-        if self.rank == 0:
+        if self.tp_rank == 0:
             for req_id in done_sending:
                 self._done_sending_count[req_id] += 1
             for req_id in done_recving:
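
The counters updated here implement a per-request barrier: rank 0 only reports a request finished once every TP rank has reported it. A minimal sketch of that aggregation rule (aggregate is hypothetical; the real code folds this logic into get_finished):

from collections import defaultdict

def aggregate(per_rank_done: list[set[str]], world_size: int) -> set[str]:
    counts: dict[str, int] = defaultdict(int)
    for done in per_rank_done:  # one set of req_ids per TP rank
        for req_id in done:
            counts[req_id] += 1
    return {req_id for req_id, n in counts.items() if n == world_size}

assert aggregate([{"a", "b"}, {"a"}], world_size=2) == {"a"}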