Skip to content

Commit d9b0a85

Browse files
authored
[bugfix][serve][llm] Fix port collisions for TP/PP with NIXL/LMCache (#57771)
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
1 parent 11fd60e commit d9b0a85

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import string
44
from typing import TYPE_CHECKING, Any, Dict
55

6+
from ray import serve
7+
68
if TYPE_CHECKING:
79
from ray.llm._internal.serve.core.configs.llm_config import LLMConfig
810

@@ -35,6 +37,31 @@ def _get_unique_suffix(self, len: int = 6) -> str:
3537
"""
3638
return "".join(random.choices(string.ascii_letters + string.digits, k=len))
3739

40+
def _compute_port_offset(self) -> int:
41+
"""Compute a deterministic port offset for this replica.
42+
43+
Uses data_parallel_rank if DP case, otherwise falls back to
44+
the replica rank assigned by Ray Serve (TP/PP case).
45+
46+
Returns:
47+
Non-negative integer offset to add to a base port.
48+
"""
49+
# Prefer explicit DP rank when available
50+
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank")
51+
if isinstance(dp_rank, int) and dp_rank >= 0:
52+
return dp_rank
53+
54+
# Fall back to Serve replica rank for TP/PP cases
55+
try:
56+
rc = serve.get_replica_context()
57+
if rc and hasattr(rc, "rank"):
58+
return rc.rank
59+
except Exception:
60+
# Best-effort fallback; avoid introducing failures in setup paths
61+
pass
62+
63+
return 0
64+
3865
@abc.abstractmethod
3966
def setup(self) -> None:
4067
"""Setup the connector backend.

python/ray/llm/_internal/serve/engines/vllm/kv_transfer/lmcache.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ class LMCacheConnectorV1Backend(BaseConnectorBackend):
2323

2424
def setup(self) -> None:
2525
"""Initialize the LMCache connector backend.
26-
This method sets up the LMCache connector by:
27-
1. Checking if LMCache is installed.
28-
2. Configuring the LMCache RPC port if not already set.
29-
3. Creating a unique LMCache RPC port across replicas.
26+
27+
Creates a unique LMCache RPC port name across replicas by appending
28+
a random suffix to the base port name.
29+
3030
Raises:
3131
ImportError: If LMCache is not installed.
3232
"""
@@ -41,21 +41,21 @@ def setup(self) -> None:
4141
kv_connector_extra_config = self.kv_transfer_config[
4242
LMCacheConnectorV1Backend.KV_CONNECTOR_EXTRA_CONFIG_FIELD_NAME
4343
]
44-
lmcache_rpc_port = (
45-
kv_connector_extra_config.get(
46-
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
47-
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
48-
)
49-
+ self._get_unique_suffix()
44+
base_value = kv_connector_extra_config.get(
45+
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
46+
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
5047
)
48+
49+
# Append random suffix for uniqueness
50+
lmcache_rpc_port_value = str(base_value) + self._get_unique_suffix()
5151
if (
5252
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
5353
in kv_connector_extra_config
5454
):
5555
logger.info(
56-
f"Setting unique {lmcache_rpc_port=} for current replica LMCacheConnectorV1."
56+
f"Setting unique lmcache_rpc_port={lmcache_rpc_port_value} for current replica."
5757
)
5858

5959
kv_connector_extra_config[
6060
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
61-
] = lmcache_rpc_port
61+
] = lmcache_rpc_port_value

python/ray/llm/_internal/serve/engines/vllm/kv_transfer/nixl.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ def _set_side_channel_port(self):
1515
"NIXL_SIDE_CHANNEL_PORT_BASE", vllm_utils.get_open_port()
1616
)
1717
)
18-
# If dp_rank is set, we should use the
19-
# base port + dp_rank as the side channel port
20-
# due to a potential ray condition for getting the free ports.
21-
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank", 0)
22-
port = base_port + dp_rank
18+
# Use a deterministic rank-based offset (DP rank if set; else replica hash)
19+
port = base_port + self._compute_port_offset()
2320
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)
2421

2522
def _set_side_channel_host(self):

0 commit comments

Comments
 (0)