@@ -90,14 +90,54 @@ def _destroy_dist_context():
9090 _set_model (_SerialModel ())
9191
9292
93+ def _find_free_port ():
94+ # Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
95+ import socket
96+
97+ sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
98+ sock .bind (("" , 0 ))
99+ port = sock .getsockname ()[1 ]
100+ sock .close ()
101+ return port
102+
103+
104+ def _setup_free_port (local_rank ):
105+ import time
106+ import os
107+
108+ port_file = "/tmp/free_port"
109+
110+ if local_rank == 0 :
111+ port = _find_free_port ()
112+ with open (port_file , "w" ) as h :
113+ h .write (str (port ))
114+ return port
115+ else :
116+ counter = 10
117+ while counter > 0 :
118+ counter -= 1
119+ time .sleep (1 )
120+ if not os .path .exists (port_file ):
121+ continue
122+ with open (port_file , "r" ) as h :
123+ port = h .readline ()
124+ return int (port )
125+
126+ raise RuntimeError ("Failed to fetch free port on local rank {}" .format (local_rank ))
127+
128+
93129@pytest .fixture ()
94130def distributed_context_single_node_nccl (local_rank , world_size ):
95131
132+ free_port = _setup_free_port (local_rank )
133+
134+ print (local_rank , "Port:" , free_port )
135+
96136 dist_info = {
97137 "backend" : "nccl" ,
98138 "world_size" : world_size ,
99139 "rank" : local_rank ,
100- "init_method" : "tcp://localhost:2223" ,
140+ "init_method" : "tcp://localhost:{}" . format ( free_port ) ,
101141 }
102142 yield _create_dist_context (dist_info , local_rank )
103143 _destroy_dist_context ()
@@ -108,11 +148,14 @@ def distributed_context_single_node_gloo(local_rank, world_size):
108148
109149 from datetime import timedelta
110150
111- init_method = "tcp://localhost:2223"
112- temp_file = None
113151 if sys .platform .startswith ("win" ):
114152 temp_file = tempfile .NamedTemporaryFile (delete = False )
115153 init_method = "file:///{}" .format (temp_file .name .replace ("\\ " , "/" ))
154+ else :
155+ free_port = _setup_free_port (local_rank )
156+ print (local_rank , "Port:" , free_port )
157+ init_method = "tcp://localhost:{}" .format (free_port )
158+ temp_file = None
116159
117160 dist_info = {
118161 "backend" : "gloo" ,
0 commit comments