Commit 07a3fd9

Use SMP rank and size when applicable (#411)
* Add smp rank
* Switch to core initialize
* Use smp size
* Cache whether SMP can be imported
* Lint
* Try import with noqa
* Add smp rank call in core
* Import only once
* Use nested except blocks
1 parent 9d2d0c3 commit 07a3fd9
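
For readers skimming the diffs below, the commit's core pattern is to probe each distributed backend once at import time, cache the result in a module-level variable, and prefer SMP's rank/size whenever SMP is initialized. A minimal standalone sketch of that pattern, assuming the same smdistributed import paths as the diff; the helper function is illustrative and not part of the commit:

# Probe for SageMaker model parallel once, trying the TensorFlow flavor first
# and falling back to the PyTorch flavor in a nested except block.
try:
    import smdistributed.modelparallel.tensorflow as smp

    _smp_imported = smp
except (ImportError, ModuleNotFoundError):
    try:
        import smdistributed.modelparallel.torch as smp

        _smp_imported = smp
    except (ImportError, ModuleNotFoundError):
        _smp_imported = None


def effective_rank_and_size(hvd_rank, hvd_size):
    # Hypothetical helper: return SMP's view of the job when it is initialized,
    # otherwise fall back to the Horovod values passed in.
    if _smp_imported is not None and smp.core.initialized:
        return smp.rank(), smp.size()
    return hvd_rank, hvd_size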

File tree: 2 files changed, 89 additions & 38 deletions

  smdebug/core/utils.py
  smdebug/tensorflow/base_hook.py
smdebug/core/utils.py

Lines changed: 67 additions & 36 deletions
@@ -23,6 +23,54 @@
 from smdebug.exceptions import IndexReaderException
 
 _is_invoked_via_smddp = None
+
+try:
+    import smdistributed.modelparallel.tensorflow as smp
+
+    _smp_imported = smp
+except (ImportError, ModuleNotFoundError):
+    try:
+        import smdistributed.modelparallel.torch as smp
+
+        _smp_imported = smp
+    except (ImportError, ModuleNotFoundError):
+        _smp_imported = None
+
+
+try:
+    import torch.distributed as dist
+
+    _torch_dist_imported = dist
+except (ImportError, ModuleNotFoundError):
+    _torch_dist_imported = None
+
+
+try:
+    import horovod.torch as hvd
+
+    _hvd_imported = hvd
+except (ModuleNotFoundError, ImportError):
+    try:
+        import horovod.tensorflow as hvd
+
+        _hvd_imported = hvd
+    except (ModuleNotFoundError, ImportError):
+        _hvd_imported = None
+
+
+try:
+    import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+    _smdataparallel_imported = smdataparallel
+except (ModuleNotFoundError, ImportError):
+    try:
+        import smdistributed.dataparallel.tensorflow as smdataparallel
+
+        _smdataparallel_imported = smdataparallel
+    except (ModuleNotFoundError, ImportError):
+        _smdataparallel_imported = None
+
+
 logger = get_logger()
@@ -317,51 +365,34 @@ def get_tb_worker():
 
 
 def get_distributed_worker():
-    """Get the rank for horovod or torch distributed. If none of them are being used,
+    """
+    Get the rank for horovod or torch distributed. If none of them are being used,
     return None"""
     rank = None
-    try:
-        import torch.distributed as dist
-    except (ImportError, ModuleNotFoundError):
-        dist = None
-        rank = None
-    if dist and hasattr(dist, "is_initialized") and dist.is_initialized():
-        rank = dist.get_rank()
-    else:
+    if (
+        _torch_dist_imported
+        and hasattr(_torch_dist_imported, "is_initialized")
+        and _torch_dist_imported.is_initialized()
+    ):
+        rank = _torch_dist_imported.get_rank()
+    elif _smp_imported and smp.core.initialized:
+        rank = smp.rank()
+    elif check_smdataparallel_env():
+        # smdistributed.dataparallel should be invoked via `mpirun`.
+        # It supports EC2 machines with 8 GPUs per machine.
+        assert smdataparallel is not None
         try:
-            import horovod.torch as hvd
-
-            if hvd.size():
-                rank = hvd.rank()
-        except (ModuleNotFoundError, ValueError, ImportError):
+            if smdataparallel.get_world_size():
+                return smdataparallel.get_rank()
+        except ValueError:
             pass
-
+    elif _hvd_imported:
         try:
-            import horovod.tensorflow as hvd
-
             if hvd.size():
                 rank = hvd.rank()
-        except (ModuleNotFoundError, ValueError, ImportError):
+        except ValueError:
             pass
 
-    # smdistributed.dataparallel should be invoked via `mpirun`.
-    # It supports EC2 machines with 8 GPUs per machine.
-    if check_smdataparallel_env():
-        try:
-            import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-            if smdataparallel.get_world_size():
-                return smdataparallel.get_rank()
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
-
-        try:
-            import smdistributed.dataparallel.tensorflow as smdataparallel
-
-            if smdataparallel.size():
-                return smdataparallel.rank()
-        except (ModuleNotFoundError, ValueError, ImportError):
-            pass
     return rank
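
With this change, get_distributed_worker() consults the cached modules in a fixed order: torch.distributed first, then SMP, then smdistributed.dataparallel, then Horovod, returning None when no backend is initialized. A hedged usage sketch; the rank-0 logic around the call is illustrative and not part of smdebug:

from smdebug.core.utils import get_distributed_worker

rank = get_distributed_worker()
if rank is None:
    # No distributed backend is active; treat this as a single-process run.
    print("single-process run")
elif rank == 0:
    # Only the lead worker performs once-per-job work, e.g. writing shared files.
    print("rank 0: writing shared artifacts")
else:
    print(f"worker {rank} running")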

smdebug/tensorflow/base_hook.py

Lines changed: 22 additions & 2 deletions
@@ -33,9 +33,11 @@
 )
 
 try:
-    pass
+    import smdistributed.modelparallel.tensorflow as smp  # noqa isort:skip
+
+    _smp_importable = True
 except ImportError:
-    pass
+    _smp_importable = False
 
 
 DEFAULT_INCLUDE_COLLECTIONS = [
@@ -183,6 +185,15 @@ def _get_worker_name(self) -> str:
         """
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
+            if _smp_importable:
+                # when model parallel is being used, there will be multiple processes
+                # with same hvd rank, hence use smp.rank
+                import smdistributed.modelparallel.tensorflow as smp
+
+                if smp.core.initialized:
+                    # if smp is in use
+                    return f"worker_{smp.rank()}"
+
             import horovod.tensorflow as hvd
 
             return f"worker_{hvd.rank()}"
@@ -260,6 +271,15 @@ def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["C
     def _get_num_workers(self):
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
+            if _smp_importable:
+                # when model parallel is being used, there will be multiple hvd process groups,
+                # hence use smp.size
+                import smdistributed.modelparallel.tensorflow as smp
+
+                if smp.core.initialized:
+                    # if smp is in use
+                    return smp.size()
+
             import horovod.tensorflow as hvd
 
             return hvd.size()
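
Both base_hook.py hunks above share the same shape: an SMP check nested inside the Horovod branch, so Horovod remains the fallback when SMP is importable but not initialized. A condensed sketch of that shape under the commit's assumptions (_smp_importable set at import time, smp.core.initialized as the runtime guard); the standalone function name is illustrative:

try:
    import smdistributed.modelparallel.tensorflow as smp  # noqa
    _smp_importable = True
except ImportError:
    _smp_importable = False


def num_workers_sketch():
    # Prefer SMP's process count when model parallelism is active; otherwise
    # report the Horovod world size, mirroring _get_num_workers above.
    if _smp_importable:
        import smdistributed.modelparallel.tensorflow as smp

        if smp.core.initialized:
            return smp.size()

    import horovod.tensorflow as hvd

    return hvd.size()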
