 from smdebug.pytorch.singleton_utils import set_hook
 from smdebug.pytorch.utils import get_reduction_of_data
 
-try:
-    import smdistributed.dataparallel.torch.distributed as smdataparallel
-except ImportError:
-    smdataparallel = None
+# smdistributed.dataparallel should be invoked via `mpirun`.
+# It supports EC2 machines with 8 GPUs per machine.
+_is_invoked_via_mpi = (
+    os.getenv("OMPI_COMM_WORLD_SIZE") is not None and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+)
+if _is_invoked_via_mpi:
+    try:
+        import smdistributed.dataparallel.torch.distributed as smdataparallel
+    except ImportError:
+        smdataparallel = None
 
 
 DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]
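
The module-level guard above keys off Open MPI's `OMPI_COMM_WORLD_SIZE` environment variable, which `mpirun` exports to every launched rank, so the `smdistributed.dataparallel` import is only attempted under an MPI launch of at least 8 processes. A minimal standalone sketch of the same check (the helper name `_invoked_via_mpi` is illustrative, not part of this patch):

```python
import os


def _invoked_via_mpi(min_ranks=8):
    # Open MPI sets OMPI_COMM_WORLD_SIZE in each rank's environment
    # (e.g. `mpirun -np 8 python train.py` yields "8"); outside of
    # mpirun the variable is absent and the check fails.
    world_size = os.getenv("OMPI_COMM_WORLD_SIZE")
    return world_size is not None and int(world_size) >= min_ranks


if _invoked_via_mpi():
    try:
        import smdistributed.dataparallel.torch.distributed as smdataparallel
    except ImportError:
        smdataparallel = None
else:
    smdataparallel = None  # sketch keeps the name bound when not launched via MPI
```
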
@@ -185,13 +191,20 @@ def _get_num_workers(self):
                 pass
 
         # Try smdataparallel
-        try:
-            import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-            if smdataparallel.get_world_size():
-                return smdataparallel.get_world_size()
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if _is_invoked_via_mpi:
+            try:
+                import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+                if smdataparallel.get_world_size():
+                    return smdataparallel.get_world_size()
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
         # Return default
         return 1
 
@@ -212,13 +225,20 @@ def _get_worker_name(self):
                 pass
 
         # Try smdataparallel
-        try:
-            import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-            if smdataparallel.get_world_size():
-                return f"worker_{smdataparallel.get_rank()}"
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if _is_invoked_via_mpi:
+            try:
+                import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+                if smdataparallel.get_world_size():
+                    return f"worker_{smdataparallel.get_rank()}"
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
         # Return default
         return DEFAULT_WORKER_NAME
 
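
`_get_num_workers` and `_get_worker_name` now carry identical copies of this guard. Purely as an illustration of how the duplication could be folded into one place (the helper `_import_smdataparallel` is hypothetical, not part of this change):

```python
import os


def _import_smdataparallel():
    """Return smdistributed.dataparallel.torch.distributed, or None if the
    process was not launched via mpirun with at least 8 ranks or the package
    is not installed."""
    world_size = os.getenv("OMPI_COMM_WORLD_SIZE")
    if world_size is None or int(world_size) < 8:
        return None
    try:
        import smdistributed.dataparallel.torch.distributed as smdataparallel

        return smdataparallel
    except ImportError:
        return None


# Hypothetical call site mirroring the patched _get_num_workers:
# smdataparallel = _import_smdataparallel()
# if smdataparallel is not None:
#     try:
#         if smdataparallel.get_world_size():
#             return smdataparallel.get_world_size()
#     except (ModuleNotFoundError, ValueError, ImportError):  # same exceptions as the patch
#         pass
```

The diff itself keeps the guard inline at each call site.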