1919 KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole )
2020from vllm .distributed .parallel_state import (
2121 get_tensor_model_parallel_rank , get_tensor_model_parallel_world_size ,
22- get_tp_group )
22+ get_tp_group , get_world_group )
2323from vllm .logger import init_logger
2424from vllm .utils import make_zmq_path , make_zmq_socket , round_down
2525from vllm .v1 .core .sched .output import SchedulerOutput
@@ -334,6 +334,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
334334 self .engine_id = engine_id
335335 self .rank = get_tensor_model_parallel_rank ()
336336 self .world_size = get_tensor_model_parallel_world_size ()
337+ self .world_rank = get_world_group ().rank_in_group
337338 self .tp_group = get_tp_group ()
338339
339340 # KV Caches and nixl tracking data.
@@ -382,7 +383,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
382383
383384 @staticmethod
384385 def _nixl_handshake_listener (metadata : NixlAgentMetadata ,
385- ready_event : threading .Event , rank : int ):
386+ ready_event : threading .Event ,
387+ world_rank : int ):
386388 """Background thread for getting new NIXL handshakes."""
387389 # NOTE(rob): this is a simple implementation. We will move
388390 # to a better approach like an ETCD server in the future.
@@ -403,7 +405,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
403405 # NOTE(rob): we need each rank to have a unique port. This
404406 # hack to keeps us moving. We will switch when moving to etcd
405407 # or where we have a single ZMQ socket in the scheduler.
406- port = envs .VLLM_NIXL_SIDE_CHANNEL_PORT + rank
408+ port = envs .VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
407409 path = make_zmq_path ("tcp" , host , port )
408410 logger .debug ("Starting listening on path: %s" , path )
409411 with zmq_ctx (zmq .ROUTER , path ) as sock :
@@ -422,7 +424,7 @@ def _nixl_handshake(self, host: str, port: int):
422424 # NOTE(rob): we need each rank to have a unique port. This is
423425 # a hack to keep us moving. We will switch when moving to etcd
424426 # or where we have a single ZMQ socket in the scheduler.
425- path = make_zmq_path ("tcp" , host , port + self .rank )
427+ path = make_zmq_path ("tcp" , host , port + self .world_rank )
426428 logger .debug ("Querying metadata on path: %s" , path )
427429 with zmq_ctx (zmq .REQ , path ) as sock :
428430 # Send query for the request.
@@ -529,7 +531,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
529531 ready_event = threading .Event ()
530532 self ._nixl_handshake_listener_t = threading .Thread (
531533 target = self ._nixl_handshake_listener ,
532- args = (metadata , ready_event , self .rank ),
534+ args = (metadata , ready_event , self .world_rank ),
533535 daemon = True ,
534536 name = "nixl_handshake_listener" )
535537 self ._nixl_handshake_listener_t .start ()
0 commit comments