|  | 
| 7 | 7 | import dataclasses | 
| 8 | 8 | import datetime | 
| 9 | 9 | import pickle | 
|  | 10 | +import socket | 
| 10 | 11 | import time | 
| 11 | 12 | from collections import deque | 
| 12 | 13 | from typing import Any, Deque, Dict, Optional, Sequence, Tuple | 
| @@ -123,6 +124,10 @@ class StatelessProcessGroup: | 
| 123 | 124 |     rank: int | 
| 124 | 125 |     world_size: int | 
| 125 | 126 |     store: torch._C._distributed_c10d.Store | 
|  | 127 | + | 
|  | 128 | +    # stores a reference to the socket so that the file descriptor stays alive | 
|  | 129 | +    socket: Optional[socket.socket] | 
|  | 130 | + | 
| 126 | 131 |     data_expiration_seconds: int = 3600  # 1 hour | 
| 127 | 132 | 
 | 
| 128 | 133 |     # dst rank -> counter | 
| @@ -234,18 +239,33 @@ def create( | 
| 234 | 239 |         can call `StatelessProcessGroup.create` to form a group, and then process A, B, | 
| 235 | 240 |         C, and D can call `StatelessProcessGroup.create` to form another group. | 
| 236 | 241 |         """ # noqa | 
|  | 242 | +        launch_server = rank == 0 | 
|  | 243 | +        if launch_server: | 
|  | 244 | +            # listen on the specified interface (instead of 0.0.0.0) | 
|  | 245 | +            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
|  | 246 | +            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | 
|  | 247 | +            listen_socket.bind((host, port)) | 
|  | 248 | +            listen_socket.listen() | 
|  | 249 | +            listen_fd = listen_socket.fileno() | 
|  | 250 | +        else: | 
|  | 251 | +            listen_socket = None | 
|  | 252 | +            listen_fd = None | 
|  | 253 | + | 
| 237 | 254 |         store = TCPStore( | 
| 238 | 255 |             host_name=host, | 
| 239 | 256 |             port=port, | 
| 240 | 257 |             world_size=world_size, | 
| 241 |  | -            is_master=(rank == 0), | 
|  | 258 | +            is_master=launch_server, | 
| 242 | 259 |             timeout=datetime.timedelta(seconds=store_timeout), | 
|  | 260 | +            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215 | 
|  | 261 | +            master_listen_fd=listen_fd, | 
| 243 | 262 |         ) | 
| 244 | 263 | 
 | 
| 245 | 264 |         return StatelessProcessGroup( | 
| 246 | 265 |             rank=rank, | 
| 247 | 266 |             world_size=world_size, | 
| 248 | 267 |             store=store, | 
|  | 268 | +            socket=listen_socket, | 
| 249 | 269 |             data_expiration_seconds=data_expiration_seconds) | 
| 250 | 270 | 
 | 
| 251 | 271 | 
 | 
|  | 
0 commit comments