diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 2cb57afd4566..e4d4008cd0a6 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -7,6 +7,7 @@ import dataclasses import datetime import pickle +import socket import time from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple @@ -123,6 +124,10 @@ class StatelessProcessGroup: rank: int world_size: int store: torch._C._distributed_c10d.Store + + # stores a reference to the socket so that the file descriptor stays alive + socket: Optional[socket.socket] + data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter @@ -234,18 +239,33 @@ def create( can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + store = TCPStore( host_name=host, port=port, world_size=world_size, - is_master=(rank == 0), + is_master=launch_server, timeout=datetime.timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, + socket=listen_socket, data_expiration_seconds=data_expiration_seconds)