| 
7 | 7 | import dataclasses  | 
8 | 8 | import datetime  | 
9 | 9 | import pickle  | 
 | 10 | +import socket  | 
10 | 11 | import time  | 
11 | 12 | from collections import deque  | 
12 | 13 | from typing import Any, Deque, Dict, Optional, Sequence, Tuple  | 
@@ -123,6 +124,10 @@ class StatelessProcessGroup:  | 
123 | 124 |     rank: int  | 
124 | 125 |     world_size: int  | 
125 | 126 |     store: torch._C._distributed_c10d.Store  | 
 | 127 | + | 
 | 128 | +    # stores a reference to the socket so that the file descriptor stays alive  | 
 | 129 | +    socket: Optional[socket.socket]  | 
 | 130 | + | 
126 | 131 |     data_expiration_seconds: int = 3600  # 1 hour  | 
127 | 132 | 
 
  | 
128 | 133 |     # dst rank -> counter  | 
@@ -234,18 +239,33 @@ def create(  | 
234 | 239 |         can call `StatelessProcessGroup.create` to form a group, and then process A, B,  | 
235 | 240 |         C, and D can call `StatelessProcessGroup.create` to form another group.  | 
236 | 241 |         """ # noqa  | 
 | 242 | +        launch_server = rank == 0  | 
 | 243 | +        if launch_server:  | 
 | 244 | +            # listen on the specified interface (instead of 0.0.0.0)  | 
 | 245 | +            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  | 
 | 246 | +            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  | 
 | 247 | +            listen_socket.bind((host, port))  | 
 | 248 | +            listen_socket.listen()  | 
 | 249 | +            listen_fd = listen_socket.fileno()  | 
 | 250 | +        else:  | 
 | 251 | +            listen_socket = None  | 
 | 252 | +            listen_fd = None  | 
 | 253 | + | 
237 | 254 |         store = TCPStore(  | 
238 | 255 |             host_name=host,  | 
239 | 256 |             port=port,  | 
240 | 257 |             world_size=world_size,  | 
241 |  | -            is_master=(rank == 0),  | 
 | 258 | +            is_master=launch_server,  | 
242 | 259 |             timeout=datetime.timedelta(seconds=store_timeout),  | 
 | 260 | +            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215  | 
 | 261 | +            master_listen_fd=listen_fd,  | 
243 | 262 |         )  | 
244 | 263 | 
 
  | 
245 | 264 |         return StatelessProcessGroup(  | 
246 | 265 |             rank=rank,  | 
247 | 266 |             world_size=world_size,  | 
248 | 267 |             store=store,  | 
 | 268 | +            socket=listen_socket,  | 
249 | 269 |             data_expiration_seconds=data_expiration_seconds)  | 
250 | 270 | 
 
  | 
251 | 271 | 
 
  | 
 | 
0 commit comments