Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import string
from typing import TYPE_CHECKING, Any, Dict

from ray import serve

if TYPE_CHECKING:
from ray.llm._internal.serve.core.configs.llm_config import LLMConfig

Expand Down Expand Up @@ -35,6 +37,31 @@ def _get_unique_suffix(self, len: int = 6) -> str:
"""
return "".join(random.choices(string.ascii_letters + string.digits, k=len))

def _compute_port_offset(self) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we should just use the replica rank to do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep.

Now `_compute_port_offset()` uses `replica_rank` from the replica context instead of the hashing approach.

So the logic is now:

  1. Use data_parallel_rank if explicitly set (DP deployments via DPServer)
  2. Fall back to replica_rank from serve context (TP/PP deployments)
  3. Return 0 as final fallback

"""Compute a deterministic port offset for this replica.

Uses data_parallel_rank when it is set (DP case); otherwise falls back to
the replica rank assigned by Ray Serve (TP/PP case), or to 0 if neither
is available.

Returns:
Non-negative integer offset to add to a base port.
"""
# Prefer explicit DP rank when available
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank")
if isinstance(dp_rank, int) and dp_rank >= 0:
return dp_rank

# Fall back to Serve replica rank for TP/PP cases
try:
rc = serve.get_replica_context()
if rc and hasattr(rc, "rank"):
return rc.rank
except Exception:
# Best-effort fallback; avoid introducing failures in setup paths
pass

return 0

@abc.abstractmethod
def setup(self) -> None:
"""Setup the connector backend.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ class LMCacheConnectorV1Backend(BaseConnectorBackend):

def setup(self) -> None:
"""Initialize the LMCache connector backend.
This method sets up the LMCache connector by:
1. Checking if LMCache is installed.
2. Configuring the LMCache RPC port if not already set.
3. Creating a unique LMCache RPC port across replicas.

Creates a unique LMCache RPC port name across replicas by appending
a random suffix to the base port name.

Raises:
ImportError: If LMCache is not installed.
"""
Expand All @@ -41,21 +41,21 @@ def setup(self) -> None:
kv_connector_extra_config = self.kv_transfer_config[
LMCacheConnectorV1Backend.KV_CONNECTOR_EXTRA_CONFIG_FIELD_NAME
]
lmcache_rpc_port = (
kv_connector_extra_config.get(
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
)
+ self._get_unique_suffix()
base_value = kv_connector_extra_config.get(
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME,
LMCacheConnectorV1Backend.DEFAULT_LMCACHE_RPC_PORT_NAME,
)

# Append random suffix for uniqueness
lmcache_rpc_port_value = str(base_value) + self._get_unique_suffix()
if (
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
in kv_connector_extra_config
):
logger.info(
f"Setting unique {lmcache_rpc_port=} for current replica LMCacheConnectorV1."
f"Setting unique lmcache_rpc_port={lmcache_rpc_port_value} for current replica."
)

kv_connector_extra_config[
LMCacheConnectorV1Backend.LMCACHE_RPC_PORT_FIELD_NAME
] = lmcache_rpc_port
] = lmcache_rpc_port_value
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@ def _set_side_channel_port(self):
"NIXL_SIDE_CHANNEL_PORT_BASE", vllm_utils.get_open_port()
)
)
# If dp_rank is set, we should use the
# base port + dp_rank as the side channel port
# due to a potential race condition for getting the free ports.
dp_rank = self.llm_config.engine_kwargs.get("data_parallel_rank", 0)
port = base_port + dp_rank
# Use a deterministic rank-based offset (DP rank if set; else Serve replica rank)
port = base_port + self._compute_port_offset()
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)

def _set_side_channel_host(self):
Expand Down