
Conversation

@nrghosh (Contributor) commented Oct 24, 2025

Multiply replica_rank by num_devices (tp × pp) to prevent port collisions when scaling to 2+ replicas with TP≥2 or PP≥2.

Root Cause

PR #57771 fixed port collisions in python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py for TP/PP by using Ray Serve's replica_rank for port offsets instead of defaulting to 0. However, that implementation doesn't account for the port spacing needed when each replica spawns multiple workers, so port ranges can still overlap.

Main issue: Consecutive replicas get consecutive port offsets (0, 1, 2, ...), but each replica actually needs num_devices (tp × pp) consecutive ports for its workers. This causes port ranges to overlap between replicas.

Example: 2 replicas, TP=2

Current implementation in base.py:_compute_port_offset():

  return rc.rank  # Returns 0, 1, 2, ...

Port allocation:

  Replica 0 (rank=0): offset=0  → base port 50000 → workers use [50000, 50001]
  Replica 1 (rank=1): offset=1  → base port 50001 → workers use [50001, 50002]
                                                                   ^^^^^ Collision!

Replica 0 Worker 1 and Replica 1 Worker 0 both bind to port 50001.

Example: 2 replicas, TP=2, PP=2

Each replica spawns 4 workers (tp × pp = 2 × 2 = 4). Each worker needs a unique port.

Current implementation (incorrect):
  Replica 0 (rank=0): offset=0  → workers use [50000, 50001, 50002, 50003]
  Replica 1 (rank=1): offset=1  → workers use [50001, 50002, 50003, 50004]
                                              ^^^^^ Multiple collisions!

With fix (multiply by num_devices = 4):
  Replica 0 (rank=0): offset=0  → workers use [50000, 50001, 50002, 50003]
  Replica 1 (rank=1): offset=4  → workers use [50004, 50005, 50006, 50007]
                                              ✓ No collisions
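
As a quick standalone check of this arithmetic (pure Python, no Ray needed; BASE_PORT and the replica/worker counts are just the numbers from the example above):

  BASE_PORT = 50000
  NUM_REPLICAS = 2
  NUM_DEVICES = 4  # tp * pp = 2 * 2

  def ports(offset_fn):
      # Ports each replica's workers bind: base + replica offset + worker index
      return [
          {BASE_PORT + offset_fn(rank) + w for w in range(NUM_DEVICES)}
          for rank in range(NUM_REPLICAS)
      ]

  old = ports(lambda rank: rank)                # pre-fix: offset = rank
  new = ports(lambda rank: rank * NUM_DEVICES)  # fixed: offset = rank * num_devices

  print(old[0] & old[1])  # {50001, 50002, 50003} -> collisions
  print(new[0] & new[1])  # set() -> disjoint ranges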

Solution:

Space replicas by num_devices (tp × pp) ports to reserve room for all workers:

Replica 0 uses ports: [base, base+1, ..., base+(num_devices-1)]
Replica 1 uses ports: [base+num_devices, base+num_devices+1, ...]

The fix uses llm_config.get_engine_config().num_devices which correctly accounts for both TP and PP workers.
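
A minimal sketch of the patched _compute_port_offset() (simplified here as a standalone function; the real method lives on a class in base.py, so the exact signature is an assumption):

  from ray import serve

  def _compute_port_offset(llm_config) -> int:
      # Space replicas num_devices ports apart so that the port ranges
      # claimed by each replica's workers never overlap.
      rc = serve.get_replica_context()
      # num_devices == tp_size * pp_size: one port per worker in this replica
      num_devices = llm_config.get_engine_config().num_devices
      return rc.rank * num_devices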

Impact:

  • Fixes port collisions when autoscaling to 2+ replicas with TP≥2 or PP≥2
  • Handles combined TP+PP scenarios correctly (e.g., TP=2, PP=2 requires 4 ports per replica)
  • Backward compatible: TP=1, PP=1 multiplies by 1 (no-op)
  • DP deployments unchanged: vLLM handles spacing internally
  • Single replica deployments unchanged: no other replica to collide with

Note (about Data Parallel)

DP deployments don't need this fix because vLLM already multiplies data_parallel_rank by tp_size for the offset internally:

  # vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:641-642
  side_channel_port = base + (data_parallel_rank * tensor_parallel_size)

So for DP the spacing is automatic, but for replica_rank we do the offset multiplication ourselves, since vLLM doesn't know about Ray Serve's replica concept. The fix uses num_devices instead of just tp_size to ensure PP workers also get unique ports.
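
To make the tp_size-vs-num_devices distinction concrete, a tiny illustration (hypothetical numbers, reusing TP=2 and PP=2 from the example above):

  TP, PP = 2, 2
  num_devices = TP * PP  # 4 workers per replica

  def replica_ports(rank, spacing, base=50000):
      # Ports a replica's workers would bind given a per-replica spacing
      return {base + rank * spacing + w for w in range(num_devices)}

  # Spacing by tp_size alone still collides once PP > 1:
  assert replica_ports(0, TP) & replica_ports(1, TP) == {50002, 50003}
  # Spacing by num_devices keeps every replica's range disjoint:
  assert not replica_ports(0, num_devices) & replica_ports(1, num_devices)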

Related: PR #57771, #55775, #58072

Initial commit message (TP-only version, later generalized to TP×PP):

Multiplies replica_rank by tensor_parallel_size to prevent port collisions when scaling to 2+ replicas with TP≥2.

Problem:
PR ray-project#57771 fixed inter-replica port collisions by using replica_rank instead
of defaulting to 0. However, it didn't account for the port space needed by
TP workers within each replica.

vLLM workers add their tp_rank (0, 1, ..., tp_size-1) to the base port at
bind time (vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:790).
Without proper spacing, consecutive replicas have overlapping port ranges:
  Replica 0 TP Worker 1: base + 0 + 1 = 50001
  Replica 1 TP Worker 0: base + 1 + 0 = 50001  ← Collision

Solution:
Space replicas by tp_size ports to reserve room for all TP workers:
  Replica 0 uses ports: [base, base+1, ..., base+(tp_size-1)]
  Replica 1 uses ports: [base+tp_size, base+tp_size+1, ...]

Impact:
- Fixes port collisions when autoscaling to 2+ replicas with TP≥2
- Backward compatible: TP=1 multiplies by 1 (no-op)
- DP deployments unchanged: vLLM handles spacing
- Single replica deployments unchanged: no other replica to collide with

Related: PR ray-project#57771, ray-project#55775

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
@nrghosh closed this Oct 24, 2025
@nrghosh reopened this Oct 24, 2025
@nrghosh added the go (add ONLY when ready to merge, run all tests) label Oct 24, 2025
@nrghosh self-assigned this Oct 24, 2025
@nrghosh marked this pull request as ready for review October 28, 2025 17:04
@nrghosh requested a review from a team as a code owner October 28, 2025 17:04
@ray-gardener (bot) added the serve (Ray Serve Related Issue) and llm labels Oct 28, 2025
@nrghosh requested a review from kouroshHakha October 28, 2025 23:23
@kouroshHakha changed the title from "serve.llm: Add TP spacing to port offset for multi-replica deployments" to "[serve][llm] Add TP spacing to port offset for multi-replica deployments" Oct 30, 2025
@kouroshHakha (Contributor) left a comment:
This fix would still have a problem when we have TP=2, PP=2, because it doesn't consider PP at all. You should use the generic num_devices API that already exists in llm_config --> engine_config.

The diff under review in _compute_port_offset():

  - return rc.rank
  + # Multiply by tp_size to reserve ports for all TP workers
  + # Each TP worker will add its tp_rank (0, 1, ..., tp_size-1)
  + return rc.rank * tp_size
@kouroshHakha (Contributor) commented on the diff:
You need to offset by tp * pp. Effectively you should use llm_config.get_engine_config().num_devices

@nrghosh (Contributor, Author) replied:
done

Follow-up commit message:

The previous fix didn't quite get it right for the TP=X, PP=Y scenario. Use llm_config.get_engine_config().num_devices instead of manually calculating tp_size, ensuring proper port spacing for both TP and PP workers.

Fixes the case where PP workers also bind NIXL ports and need spacing in addition to TP workers.

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
@nrghosh changed the title from "[serve][llm] Add TP spacing to port offset for multi-replica deployments" to "[serve][llm] Add TP*PP spacing to port offset for multi-replica deployments" Oct 30, 2025
@kouroshHakha merged commit 71a2f40 into ray-project:master Oct 31, 2025
6 checks passed