
Commit ad19c73

smdistributed.dataparallel environment check (#140)
* smdistributed.dataparallel environment check
* addressed comments
* Modified check_smdataparallel_env logic
1 parent 6662793 · commit ad19c73


3 files changed: 38 additions, 29 deletions


smdebug/core/utils.py

Lines changed: 30 additions & 7 deletions
```diff
@@ -22,6 +22,7 @@
 from smdebug.core.logger import get_logger
 from smdebug.exceptions import IndexReaderException
 
+_is_invoked_via_smddp = None
 logger = get_logger()
 
 
@@ -345,24 +346,20 @@ def get_distributed_worker():
 
     # smdistributed.dataparallel should be invoked via `mpirun`.
     # It supports EC2 machines with 8 GPUs per machine.
-    _is_invoked_via_mpi = (
-        os.getenv("OMPI_COMM_WORLD_SIZE") is not None
-        and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
-    )
-    if _is_invoked_via_mpi:
+    if check_smdataparallel_env():
         try:
             import smdistributed.dataparallel.torch.distributed as smdataparallel
 
             if smdataparallel.get_world_size():
-                rank = smdataparallel.get_rank()
+                return smdataparallel.get_rank()
         except (ModuleNotFoundError, ValueError, ImportError):
             pass
 
         try:
             import smdistributed.dataparallel.tensorflow as smdataparallel
 
             if smdataparallel.size():
-                rank = smdataparallel.rank()
+                return smdataparallel.rank()
         except (ModuleNotFoundError, ValueError, ImportError):
             pass
     return rank
@@ -474,3 +471,29 @@ def __exit__(self, *args):
         shutil.rmtree(self.out_dir, ignore_errors=True)
         if self.tensorboard_dir:
             shutil.rmtree(self.tensorboard_dir, ignore_errors=True)
+
+
+def check_smdataparallel_env():
+    # Check to ensure it is invoked by mpi and the SM distribution is `dataparallel`
+    global _is_invoked_via_smddp
+    if _is_invoked_via_smddp is None:
+        _is_invoked_via_mpi = (
+            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
+            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
+        )
+        if os.getenv("SM_FRAMEWORK_PARAMS") is None:
+            _is_invoked_via_smddp = False
+        else:
+            try:
+                smddp_flag = json.loads(os.getenv("SM_FRAMEWORK_PARAMS"))
+            except:
+                _is_invoked_via_smddp = False
+                return _is_invoked_via_smddp
+            if (
+                smddp_flag.get("sagemaker_distributed_dataparallel_enabled", False)
+                and _is_invoked_via_mpi
+            ):
+                _is_invoked_via_smddp = True
+            else:
+                _is_invoked_via_smddp = False
+    return _is_invoked_via_smddp
```
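
For reference, the new check keys off two pieces of the SageMaker launch environment: `OMPI_COMM_WORLD_SIZE`, which `mpirun` sets, and the JSON in `SM_FRAMEWORK_PARAMS`. A minimal sketch of the behavior, using illustrative values rather than a real SageMaker job:

```python
import json
import os

from smdebug.core.utils import check_smdataparallel_env

# Illustrative values only: on SageMaker, mpirun sets OMPI_COMM_WORLD_SIZE
# and the platform sets SM_FRAMEWORK_PARAMS before user code runs.
os.environ["OMPI_COMM_WORLD_SIZE"] = "8"
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps(
    {"sagemaker_distributed_dataparallel_enabled": True}
)

print(check_smdataparallel_env())  # True with the values above
```

Because the result is memoized in the module-level `_is_invoked_via_smddp`, both variables must be set before the first call; later changes to the environment in the same process are not observed.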

smdebug/pytorch/hook.py

Lines changed: 6 additions & 16 deletions
```diff
@@ -12,7 +12,7 @@
 from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
 from smdebug.core.hook import CallbackHook
 from smdebug.core.json_config import DEFAULT_WORKER_NAME
-from smdebug.core.utils import make_numpy_array
+from smdebug.core.utils import check_smdataparallel_env, make_numpy_array
 from smdebug.profiler.hvd_trace_file_rotation import HvdTraceFileRotation
 from smdebug.profiler.profiler_config_parser import MetricsCategory, ProfilerConfigParser
 from smdebug.profiler.profiler_constants import CONVERT_TO_MICROSECS
@@ -25,14 +25,12 @@
 
 # smdistributed.dataparallel should be invoked via `mpirun`.
 # It supports EC2 machines with 8 GPUs per machine.
-_is_invoked_via_mpi = (
-    os.getenv("OMPI_COMM_WORLD_SIZE") is not None and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
-)
-if _is_invoked_via_mpi:
+smdataparallel = None
+if check_smdataparallel_env():
     try:
         import smdistributed.dataparallel.torch.distributed as smdataparallel
     except ImportError:
-        smdataparallel = None
+        pass
 
 
 DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]
@@ -193,11 +191,7 @@ def _get_num_workers(self):
         # Try smdataparallel
         # smdistributed.dataparallel should be invoked via `mpirun`.
         # It supports EC2 machines with 8 GPUs per machine.
-        _is_invoked_via_mpi = (
-            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
-            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
-        )
-        if _is_invoked_via_mpi:
+        if check_smdataparallel_env():
             try:
                 import smdistributed.dataparallel.torch.distributed as smdataparallel
 
@@ -227,11 +221,7 @@ def _get_worker_name(self):
         # Try smdataparallel
         # smdistributed.dataparallel should be invoked via `mpirun`.
         # It supports EC2 machines with 8 GPUs per machine.
-        _is_invoked_via_mpi = (
-            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
-            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
-        )
-        if _is_invoked_via_mpi:
+        if check_smdataparallel_env():
             try:
                 import smdistributed.dataparallel.torch.distributed as smdataparallel
 
```
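One detail worth noting in the module-level hunk: `smdataparallel = None` is now assigned unconditionally before the guarded import, so later code can branch on the sentinel whether or not the environment check passed or the import succeeded. A minimal sketch of the pattern (the `world_size` helper is hypothetical, not part of the hook):

```python
from smdebug.core.utils import check_smdataparallel_env

# Sentinel first, so the module-level name is always bound.
smdataparallel = None
if check_smdataparallel_env():
    try:
        # Rebinds the module-level name only if the import succeeds.
        import smdistributed.dataparallel.torch.distributed as smdataparallel
    except ImportError:
        pass


def world_size():
    # Hypothetical helper: callers test the sentinel instead of re-importing.
    return smdataparallel.get_world_size() if smdataparallel else 1
```
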
smdebug/tensorflow/base_hook.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -14,7 +14,7 @@
 from smdebug.core.hook import BaseHook
 from smdebug.core.modes import ModeKeys
 from smdebug.core.reductions import get_numpy_reduction, get_reduction_tensor_name
-from smdebug.core.utils import make_numpy_array, serialize_tf_device
+from smdebug.core.utils import check_smdataparallel_env, make_numpy_array, serialize_tf_device
 from smdebug.core.writer import FileWriter
 
 # Local
@@ -135,11 +135,7 @@ def _get_distribution_strategy(self) -> TFDistributionStrategy:
 
         # smdistributed.dataparallel should be invoked via `mpirun`.
         # It supports EC2 machines with 8 GPUs per machine.
-        _is_invoked_via_mpi = (
-            os.getenv("OMPI_COMM_WORLD_SIZE") is not None
-            and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8
-        )
-        if _is_invoked_via_mpi:
+        if check_smdataparallel_env():
             try:
                 import smdistributed.dataparallel.tensorflow as smdataparallel
 
```
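Since all three files now route through the same memoized check, it can be exercised in isolation by manipulating the two environment variables. A sketch of a pytest-style test (the test structure is ours, not from this commit); resetting the module-level cache between cases is required because the result is memoized:

```python
import json

import smdebug.core.utils as utils


def test_check_smdataparallel_env(monkeypatch):
    # Reset the cache so the check re-reads the environment.
    utils._is_invoked_via_smddp = None
    monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "8")
    monkeypatch.setenv(
        "SM_FRAMEWORK_PARAMS",
        json.dumps({"sagemaker_distributed_dataparallel_enabled": True}),
    )
    assert utils.check_smdataparallel_env() is True

    # Fewer than 8 MPI processes is not treated as a smdataparallel launch.
    utils._is_invoked_via_smddp = None
    monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "4")
    assert utils.check_smdataparallel_env() is False
```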