
Commit 2aa3c9e (1 parent: c384337)

Cherry picking the changes from PR#138 to this branch

3 files changed: +75 lines, -41 lines


smdebug/core/utils.py

Lines changed: 22 additions & 15 deletions
@@ -343,21 +343,28 @@ def get_distributed_worker():
     except (ModuleNotFoundError, ValueError, ImportError):
         pass
 
-    try:
-        import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-        if smdataparallel.get_world_size():
-            rank = smdataparallel.get_rank()
-    except (ModuleNotFoundError, ValueError, ImportError):
-        pass
-
-    try:
-        import smdistributed.dataparallel.tensorflow as smdataparallel
-
-        if smdataparallel.size():
-            rank = smdataparallel.rank()
-    except (ModuleNotFoundError, ValueError, ImportError):
-        pass
+    # smdistributed.dataparallel should be invoked via `mpirun`.
+    # It supports EC2 machines with 8 GPUs per machine.
+    _is_invoked_via_mpi = (
+        os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+        and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+    )
+    if _is_invoked_via_mpi:
+        try:
+            import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+            if smdataparallel.get_world_size():
+                rank = smdataparallel.get_rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
+
+        try:
+            import smdistributed.dataparallel.tensorflow as smdataparallel
+
+            if smdataparallel.size():
+                rank = smdataparallel.rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
     return rank
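The same guard is repeated in each file of this commit. As a standalone sketch of the pattern (not code from this commit; the helper name `_launched_via_mpi` is illustrative), the check reduces to reading the Open MPI world-size variable and requiring at least 8 processes before attempting the import:

import os

# Sketch of the guard added in this commit (helper name is illustrative).
# `mpirun` from Open MPI exports OMPI_COMM_WORLD_SIZE with the total number
# of launched processes; smdistributed.dataparallel runs one process per GPU
# and expects at least 8 GPUs, so a smaller or missing value means the
# library should not be imported at all.
def _launched_via_mpi(min_world_size=8):
    world_size = os.getenv("OMPI_COMM_WORLD_SIZE")
    return world_size is not None and int(world_size) >= min_world_size


if _launched_via_mpi():
    try:
        import smdistributed.dataparallel.torch.distributed as smdataparallel
    except ImportError:
        smdataparallel = None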

smdebug/pytorch/hook.py

Lines changed: 38 additions & 18 deletions
@@ -23,10 +23,16 @@
 from smdebug.pytorch.singleton_utils import set_hook
 from smdebug.pytorch.utils import get_reduction_of_data
 
-try:
-    import smdistributed.dataparallel.torch.distributed as smdataparallel
-except ImportError:
-    smdataparallel = None
+# smdistributed.dataparallel should be invoked via `mpirun`.
+# It supports EC2 machines with 8 GPUs per machine.
+_is_invoked_via_mpi = (
+    os.getenv("OMPI_COMM_WORLD_SIZE") is not None and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+)
+if _is_invoked_via_mpi:
+    try:
+        import smdistributed.dataparallel.torch.distributed as smdataparallel
+    except ImportError:
+        smdataparallel = None
 
 
 DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]
@@ -185,13 +191,20 @@ def _get_num_workers(self):
                 pass
 
         # Try smdataparallel
-        try:
-            import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-            if smdataparallel.get_world_size():
-                return smdataparallel.get_world_size()
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if _is_invoked_via_mpi:
+            try:
+                import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+                if smdataparallel.get_world_size():
+                    return smdataparallel.get_world_size()
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
         # Return default
         return 1

@@ -212,13 +225,20 @@ def _get_worker_name(self):
                 pass
 
         # Try smdataparallel
-        try:
-            import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-            if smdataparallel.get_world_size():
-                return f"worker_{smdataparallel.get_rank()}"
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if _is_invoked_via_mpi:
+            try:
+                import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+                if smdataparallel.get_world_size():
+                    return f"worker_{smdataparallel.get_rank()}"
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
         # Return default
         return DEFAULT_WORKER_NAME

smdebug/tensorflow/base_hook.py

Lines changed: 15 additions & 8 deletions
@@ -133,14 +133,21 @@ def _get_distribution_strategy(self) -> TFDistributionStrategy:
        except (ModuleNotFoundError, ValueError, ImportError):
            pass
 
-        try:
-            import smdistributed.dataparallel.tensorflow as smdataparallel
-
-            # The total number of GPUs across all the nodes in the cluster
-            if smdataparallel.size():
-                return TFDistributionStrategy.SMDATAPARALLEL
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if _is_invoked_via_mpi:
+            try:
+                import smdistributed.dataparallel.tensorflow as smdataparallel
+
+                # The total number of GPUs across all the nodes in the cluster
+                if smdataparallel.size():
+                    return TFDistributionStrategy.SMDATAPARALLEL
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
 
         strat = tf.distribute.get_strategy()
         if is_mirrored_strategy(strat):
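To sanity-check the boundary behaviour of the guard introduced above, a quick hypothetical snippet can exercise the three interesting cases; the `_guard` helper below simply mirrors the inline expression from the diffs and is not part of this commit:

import os

# Hypothetical check of the guard's boundary behaviour; `_guard` mirrors
# the inline expression used in the changes above.
def _guard(min_size=8):
    world_size = os.getenv("OMPI_COMM_WORLD_SIZE")
    return world_size is not None and int(world_size) >= min_size

os.environ.pop("OMPI_COMM_WORLD_SIZE", None)
assert _guard() is False   # not launched via mpirun

os.environ["OMPI_COMM_WORLD_SIZE"] = "4"
assert _guard() is False   # launched via mpirun, but fewer than 8 processes

os.environ["OMPI_COMM_WORLD_SIZE"] = "8"
assert _guard() is True    # one 8-GPU node under mpirun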
