Skip to content

Commit c74a53f

Browse files
committed
serve.llm: Fix port collisions for TP/PP with NIXL/LMCache
Extends port collision fix to Tensor Parallelism (TP) and Pipeline Parallelism (PP) scenarios. Previous fix (PR ray-project#55802) only addressed Data Parallelism by using explicit data_parallel_rank. Changes: - base.py: Added _compute_port_offset() method with fallback logic * Priority 1: Use data_parallel_rank if set (DP case) * Priority 2: Hash replica_tag for deterministic offset (TP/PP case) * Fallback: Return 0 - nixl_connector.py: Use _compute_port_offset() instead of dp_rank - lmcache_connector_v1.py: Add numeric port support with offset logic Fixes port collision errors in TP/PP deployments: - Multiple workers no longer bind to same port - Prevents NIXL_ERR_BACKEND and ZMQ errors - Enables successful deployment with pipeline_parallel_size > 1 Reproduction: Deployed Ray Serve with pipeline_parallel_size=2 and NIXL on Ray 3.0.0.dev0 (8 x L4 GPU cluster). Before fix, all workers used identical port (e.g., 52910), causing NIXL_ERR_BACKEND. Logs showed: 'Creating v1 connector with engine_id: ...-52910 [repeated 3x]' After fix, each worker receives unique port via replica tag hashing, eliminating collisions. Related: ray-project#55775 Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
1 parent 6f9ef13 commit c74a53f

File tree

3 files changed

+63
-19
lines changed

3 files changed

+63
-19
lines changed

python/ray/llm/_internal/serve/deployments/llm/vllm/kv_transfer_backends/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,38 @@ def _get_unique_suffix(self, len: int = 6) -> str:
3535
"""
3636
return "".join(random.choices(string.ascii_letters + string.digits, k=len))
3737

38+
def _compute_port_offset(self) -> int:
39+
"""Compute a deterministic port offset for this replica/process.
40+
41+
Priority:
42+
1) data_parallel_rank if present (set by DPServer).
43+
2) Stable hash of Serve replica tag (avoids cross-replica collisions when TP/PP only).
44+
45+
Returns:
46+
A small non-negative integer to add to a base port.
47+
"""
48+
# Prefer explicit DP rank when available
49+
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank")
50+
if isinstance(dp_rank, int) and dp_rank >= 0:
51+
return dp_rank
52+
53+
# Fall back to a stable hash of the Serve replica tag if available
54+
try:
55+
# Import locally to avoid import-time side effects
56+
from ray import serve # type: ignore
57+
58+
rc = serve.get_replica_context()
59+
if rc and getattr(rc, "replica_tag", None):
60+
import zlib
61+
62+
# Keep the offset bounded to avoid large jumps
63+
return zlib.adler32(rc.replica_tag.encode("utf-8")) % 1024
64+
except Exception:
65+
# Best-effort fallback; avoid introducing failures in setup paths
66+
pass
67+
68+
return 0
69+
3870
@abc.abstractmethod
3971
def setup(self) -> None:
4072
"""Setup the connector backend.

python/ray/llm/_internal/serve/deployments/llm/vllm/kv_transfer_backends/lmcache_connector_v1.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ def setup(self) -> None:
2525
"""Initialize the LMCache connector backend.
2626
This method sets up the LMCache connector by:
2727
1. Checking if LMCache is installed.
28-
2. Configuring the LMCache RPC port if not already set.
29-
3. Creating a unique LMCache RPC port across replicas.
28+
2. Configuring the LMCache RPC port name/value if not already set.
29+
3. Creating a unique LMCache RPC port across replicas either by
30+
appending a random suffix (default behavior for string port names),
31+
or by adding a rank-based integer offset when a numeric base is provided.
3032
Raises:
3133
ImportError: If LMCache is not installed.
3234
"""
@@ -41,21 +43,34 @@ def setup(self) -> None:
4143
kv_connector_extra_config = self.kv_transfer_config[
4244
LMCacheConnectorV1Backend.KV_CONNECTOR_EXTRA_CONFIG_FIELD_NAME
4345
]
44-
lmcache_rpc_port = (
45-
kv_connector_extra_config.get(
46-
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
47-
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
48-
)
49-
+ self._get_unique_suffix()
46+
# Determine the desired style of RPC port configuration.
47+
# If user passes a numeric base (e.g., 50000), add a deterministic
48+
# rank-based offset to avoid collisions across DP/TP/PP.
49+
# Otherwise, default to string-based name + random suffix.
50+
base_value = kv_connector_extra_config.get(
51+
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
52+
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
5053
)
51-
if (
52-
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
53-
in kv_connector_extra_config
54-
):
54+
55+
if isinstance(base_value, int):
56+
# Numeric base; add rank-based offset and set as int
57+
offset = self._compute_port_offset()
58+
lmcache_rpc_port_value = int(base_value) + int(offset)
5559
logger.info(
56-
f"Setting unique {lmcache_rpc_port=} for current replica LMCacheConnectorV1."
60+
f"Setting LMCache numeric rpc port base={base_value} offset={offset} value={lmcache_rpc_port_value}."
5761
)
62+
else:
63+
# String name; append random suffix for uniqueness
64+
base_str = str(base_value)
65+
lmcache_rpc_port_value = base_str + self._get_unique_suffix()
66+
if (
67+
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
68+
in kv_connector_extra_config
69+
):
70+
logger.info(
71+
f"Setting unique lmcache_rpc_port={lmcache_rpc_port_value} for current replica LMCacheConnectorV1."
72+
)
5873

5974
kv_connector_extra_config[
6075
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
61-
] = lmcache_rpc_port
76+
] = lmcache_rpc_port_value

python/ray/llm/_internal/serve/deployments/llm/vllm/kv_transfer_backends/nixl_connector.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ def _set_side_channel_port(self):
1515
"NIXL_SIDE_CHANNEL_PORT_BASE", vllm_utils.get_open_port()
1616
)
1717
)
18-
# If dp_rank is set, we should use the
19-
# base port + dp_rank as the side channel port
20-
# due to a potential ray condition for getting the free ports.
21-
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank", 0)
22-
port = base_port + dp_rank
18+
# Use a deterministic rank-based offset (DP rank if set; else replica hash)
19+
port = base_port + self._compute_port_offset()
2320
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)
2421

2522
def _set_side_channel_host(self):

0 commit comments

Comments
 (0)