|
22 | 22 | from smdebug.core.logger import get_logger |
23 | 23 | from smdebug.exceptions import IndexReaderException |
24 | 24 |
|
# Process-wide cache for check_smdataparallel_env(); None means "not yet
# determined", afterwards it holds the boolean answer so the environment
# is only inspected once per process.
_is_invoked_via_smddp = None
# Module-level logger shared by the helpers in this file.
logger = get_logger()
26 | 27 |
|
27 | 28 |
|
@@ -345,24 +346,20 @@ def get_distributed_worker(): |
345 | 346 |
|
346 | 347 | # smdistributed.dataparallel should be invoked via `mpirun`. |
347 | 348 | # It supports EC2 machines with 8 GPUs per machine. |
348 | | - _is_invoked_via_mpi = ( |
349 | | - os.getenv("OMPI_COMM_WORLD_SIZE") is not None |
350 | | - and int(os.getenv("OMPI_COMM_WORLD_SIZE")) >= 8 |
351 | | - ) |
352 | | - if _is_invoked_via_mpi: |
| 349 | + if check_smdataparallel_env(): |
353 | 350 | try: |
354 | 351 | import smdistributed.dataparallel.torch.distributed as smdataparallel |
355 | 352 |
|
356 | 353 | if smdataparallel.get_world_size(): |
357 | | - rank = smdataparallel.get_rank() |
| 354 | + return smdataparallel.get_rank() |
358 | 355 | except (ModuleNotFoundError, ValueError, ImportError): |
359 | 356 | pass |
360 | 357 |
|
361 | 358 | try: |
362 | 359 | import smdistributed.dataparallel.tensorflow as smdataparallel |
363 | 360 |
|
364 | 361 | if smdataparallel.size(): |
365 | | - rank = smdataparallel.rank() |
| 362 | + return smdataparallel.rank() |
366 | 363 | except (ModuleNotFoundError, ValueError, ImportError): |
367 | 364 | pass |
368 | 365 | return rank |
@@ -474,3 +471,29 @@ def __exit__(self, *args): |
474 | 471 | shutil.rmtree(self.out_dir, ignore_errors=True) |
475 | 472 | if self.tensorboard_dir: |
476 | 473 | shutil.rmtree(self.tensorboard_dir, ignore_errors=True) |
| 474 | + |
| 475 | + |
def check_smdataparallel_env():
    """Return True iff running under SageMaker distributed data parallel (SMDDP).

    A job counts as SMDDP when both hold:
      * ``SM_FRAMEWORK_PARAMS`` is valid JSON with
        ``sagemaker_distributed_dataparallel_enabled`` set to a truthy value, and
      * it was launched via ``mpirun`` (``OMPI_COMM_WORLD_SIZE`` >= 8, i.e. an
        EC2 machine with 8 GPUs per node).

    The answer is cached in the module-level ``_is_invoked_via_smddp`` so the
    environment is only inspected on the first call.

    Returns:
        bool: whether this process was invoked via SMDDP.
    """
    global _is_invoked_via_smddp
    if _is_invoked_via_smddp is not None:
        return _is_invoked_via_smddp

    # Default to False; flip to True only when every condition checks out.
    _is_invoked_via_smddp = False

    framework_params = os.getenv("SM_FRAMEWORK_PARAMS")
    if framework_params is None:
        return _is_invoked_via_smddp

    try:
        smddp_flag = json.loads(framework_params)
    except ValueError:
        # Was a bare `except:`; json.JSONDecodeError is a ValueError, and a
        # malformed SM_FRAMEWORK_PARAMS simply means "not SMDDP".
        return _is_invoked_via_smddp

    # Read OMPI_COMM_WORLD_SIZE once, and only when it can matter — the
    # original converted it with int() unconditionally, which could raise on
    # a malformed value even when SM_FRAMEWORK_PARAMS was absent.
    world_size = os.getenv("OMPI_COMM_WORLD_SIZE")
    _is_invoked_via_mpi = world_size is not None and int(world_size) >= 8

    _is_invoked_via_smddp = bool(
        smddp_flag.get("sagemaker_distributed_dataparallel_enabled", False)
        and _is_invoked_via_mpi
    )
    return _is_invoked_via_smddp
0 commit comments