Skip to content

Commit a5bd628

Browse files
wseaton authored and amitm02 committed
[P/D] NixlConnector DP fixes (vllm-project#18903)
Signed-off-by: Will Eaton <weaton@redhat.com>
Signed-off-by: amit <amit.man@gmail.com>
1 parent ee9dd34 commit a5bd628

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,8 @@ def create_connector_v1(
7070
connector_module = importlib.import_module(connector_module_path)
7171
connector_cls = getattr(connector_module, connector_name)
7272
assert issubclass(connector_cls, KVConnectorBase_V1)
73-
logger.info("Creating v1 connector with name: %s", connector_name)
73+
logger.info("Creating v1 connector with name: %s and engine_id: %s",
74+
connector_name, kv_transfer_config.engine_id)
7475
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
7576
# Scheduler connector:
7677
# - Co-locate with scheduler process

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
2020
from vllm.distributed.parallel_state import (
2121
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
22-
get_tp_group)
22+
get_tp_group, get_world_group)
2323
from vllm.logger import init_logger
2424
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
2525
from vllm.v1.core.sched.output import SchedulerOutput
@@ -334,6 +334,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
334334
self.engine_id = engine_id
335335
self.rank = get_tensor_model_parallel_rank()
336336
self.world_size = get_tensor_model_parallel_world_size()
337+
self.world_rank = get_world_group().rank_in_group
337338
self.tp_group = get_tp_group()
338339

339340
# KV Caches and nixl tracking data.
@@ -382,7 +383,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
382383

383384
@staticmethod
384385
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
385-
ready_event: threading.Event, rank: int):
386+
ready_event: threading.Event,
387+
world_rank: int):
386388
"""Background thread for getting new NIXL handshakes."""
387389
# NOTE(rob): this is a simple implementation. We will move
388390
# to a better approach like an ETCD server in the future.
@@ -403,7 +405,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
403405
# NOTE(rob): we need each rank to have a unique port. This
404406
# hack keeps us moving. We will switch when moving to etcd
405407
# or where we have a single ZMQ socket in the scheduler.
406-
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
408+
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
407409
path = make_zmq_path("tcp", host, port)
408410
logger.debug("Starting listening on path: %s", path)
409411
with zmq_ctx(zmq.ROUTER, path) as sock:
@@ -422,7 +424,7 @@ def _nixl_handshake(self, host: str, port: int):
422424
# NOTE(rob): we need each rank to have a unique port. This is
423425
# a hack to keep us moving. We will switch when moving to etcd
424426
# or where we have a single ZMQ socket in the scheduler.
425-
path = make_zmq_path("tcp", host, port + self.rank)
427+
path = make_zmq_path("tcp", host, port + self.world_rank)
426428
logger.debug("Querying metadata on path: %s", path)
427429
with zmq_ctx(zmq.REQ, path) as sock:
428430
# Send query for the request.
@@ -529,7 +531,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
529531
ready_event = threading.Event()
530532
self._nixl_handshake_listener_t = threading.Thread(
531533
target=self._nixl_handshake_listener,
532-
args=(metadata, ready_event, self.rank),
534+
args=(metadata, ready_event, self.world_rank),
533535
daemon=True,
534536
name="nixl_handshake_listener")
535537
self._nixl_handshake_listener_t.start()

vllm/v1/engine/core.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -707,6 +707,15 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
707707
assert dp_size > 1
708708
assert 0 <= local_dp_rank <= dp_rank < dp_size
709709

710+
if vllm_config.kv_transfer_config is not None:
711+
# modify the engine_id and append the local_dp_rank to it to ensure
712+
# that the kv_transfer_config is unique for each DP rank.
713+
vllm_config.kv_transfer_config.engine_id = (
714+
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
715+
)
716+
logger.debug("Setting kv_transfer_config.engine_id to %s",
717+
vllm_config.kv_transfer_config.engine_id)
718+
710719
from vllm.platforms import current_platform
711720
device_control_env_var = current_platform.device_control_env_var
712721
world_size = vllm_config.parallel_config.world_size

0 commit comments

Comments
 (0)