vllm/config/parallel.py (8 changes: 6 additions & 2 deletions)

@@ -593,10 +593,14 @@ def __post_init__(self) -> None:
                 "max_parallel_loading_workers is currently "
                 "not supported and will be ignored."
             )
-        if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
+        allowed_backends = ("mp", "uni", "external_launcher")
+        if (
+            self.distributed_executor_backend not in allowed_backends
+            and self.nnodes > 1
+        ):
             raise ValueError(
                 "nnodes > 1 can only be set when distributed executor "
-                "backend is mp or uni."
+                "backend is mp, uni or external_launcher."
             )
 
     @property
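For reference, the effect of the relaxed check can be illustrated with a minimal, self-contained sketch. The SimpleParallelConfig class below is a hypothetical stand-in for vLLM's real ParallelConfig, kept only to show that nnodes > 1 is now accepted with the external_launcher backend in addition to mp and uni, while other backends still raise.

from dataclasses import dataclass


@dataclass
class SimpleParallelConfig:
    # Hypothetical stand-in for vLLM's ParallelConfig, for illustration only.
    distributed_executor_backend: str = "mp"
    nnodes: int = 1

    def validate(self) -> None:
        # Mirrors the relaxed check in __post_init__: nnodes > 1 is only
        # allowed for the mp, uni, and (newly) external_launcher backends.
        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
            raise ValueError(
                "nnodes > 1 can only be set when distributed executor "
                "backend is mp, uni or external_launcher."
            )


# Accepted after this change: external_launcher across 2 nodes.
SimpleParallelConfig(distributed_executor_backend="external_launcher", nnodes=2).validate()

# Still rejected: any other backend (e.g. ray) with nnodes > 1.
try:
    SimpleParallelConfig(distributed_executor_backend="ray", nnodes=2).validate()
except ValueError as exc:
    print(exc)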
vllm/distributed/parallel_state.py (39 changes: 21 additions & 18 deletions)

@@ -1169,33 +1169,36 @@ def init_distributed_environment(
     from vllm.config import get_current_vllm_config
 
     config = get_current_vllm_config()
-    if config is not None and config.parallel_config.nnodes > 1:
-        parallel_config = config.parallel_config
-        ip = parallel_config.master_addr
-        rank = parallel_config.data_parallel_rank * world_size + rank
-        world_size = parallel_config.world_size_across_dp
-        port = parallel_config.master_port
-        distributed_init_method = get_distributed_init_method(ip, port)
-    elif (
+    if (
         config is not None
-        and config.parallel_config.data_parallel_size > 1
         and config.parallel_config.distributed_executor_backend != "external_launcher"
+        and (
+            config.parallel_config.nnodes > 1
+            or config.parallel_config.data_parallel_size > 1
+        )
     ):
         parallel_config = config.parallel_config
         # adjust to take into account data parallelism
         # offset the rank by the data parallel rank
         rank = parallel_config.data_parallel_rank * world_size + rank
         # adjust the world size to take into account data parallelism
         world_size = parallel_config.world_size_across_dp
-        ip = parallel_config.data_parallel_master_ip
-        port = parallel_config.get_next_dp_init_port()
-        distributed_init_method = get_distributed_init_method(ip, port)
-        logger.debug(
-            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
-            world_size,
-            rank,
-            distributed_init_method,
-        )
+
+        # Use appropriate IP and port based on configuration
+        if parallel_config.nnodes > 1:
+            ip = parallel_config.master_addr
+            port = parallel_config.master_port
+            distributed_init_method = get_distributed_init_method(ip, port)
+        else:
+            ip = parallel_config.data_parallel_master_ip
+            port = parallel_config.get_next_dp_init_port()
+            distributed_init_method = get_distributed_init_method(ip, port)
+        logger.debug(
+            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+            world_size,
+            rank,
+            distributed_init_method,
+        )
     if not torch.distributed.is_initialized():
         logger.info(
             "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
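To make the merged branch easier to follow, here is a small self-contained sketch with hypothetical values: it offsets the per-replica rank into a global rank, widens the world size across data-parallel replicas, and picks the rendezvous endpoint from master_addr/master_port when nnodes > 1, otherwise from the data-parallel master. The adjust_for_dp helper, the world_size * data_parallel_size relation, and the tcp:// init-method format are assumptions of this sketch, not code from the PR.

def adjust_for_dp(
    rank: int,
    world_size: int,
    data_parallel_rank: int,
    data_parallel_size: int,
    nnodes: int,
    master_addr: str,
    master_port: int,
    data_parallel_master_ip: str,
    dp_init_port: int,
) -> tuple[int, int, str]:
    # offset the rank by the data parallel rank
    rank = data_parallel_rank * world_size + rank
    # adjust the world size to take into account data parallelism
    # (assumed here to be world_size * data_parallel_size)
    world_size = world_size * data_parallel_size

    # use the static multi-node rendezvous when nnodes > 1, else the DP master
    if nnodes > 1:
        ip, port = master_addr, master_port
    else:
        ip, port = data_parallel_master_ip, dp_init_port
    # assumed "tcp://ip:port" form for the distributed init method
    return rank, world_size, f"tcp://{ip}:{port}"


# e.g. DP replica 3 of 4, each replica spanning 2 ranks, on a single node:
print(adjust_for_dp(
    rank=1,
    world_size=2,
    data_parallel_rank=3,
    data_parallel_size=4,
    nnodes=1,
    master_addr="10.0.0.1",
    master_port=29500,
    data_parallel_master_ip="127.0.0.1",
    dp_init_port=5570,
))
# (7, 8, 'tcp://127.0.0.1:5570')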