|
23 | 23 | from smdebug.exceptions import IndexReaderException |
24 | 24 |
|
25 | 25 | _is_invoked_via_smddp = None |
| 26 | + |
| 27 | +try: |
| 28 | + import smdistributed.modelparallel.tensorflow as smp |
| 29 | + |
| 30 | + _smp_imported = smp |
| 31 | +except (ImportError, ModuleNotFoundError): |
| 32 | + try: |
| 33 | + import smdistributed.modelparallel.torch as smp |
| 34 | + |
| 35 | + _smp_imported = smp |
| 36 | + except (ImportError, ModuleNotFoundError): |
| 37 | + _smp_imported = None |
| 38 | + |
| 39 | + |
| 40 | +try: |
| 41 | + import torch.distributed as dist |
| 42 | + |
| 43 | + _torch_dist_imported = dist |
| 44 | +except (ImportError, ModuleNotFoundError): |
| 45 | + _torch_dist_imported = None |
| 46 | + |
| 47 | + |
| 48 | +try: |
| 49 | + import horovod.torch as hvd |
| 50 | + |
| 51 | + _hvd_imported = hvd |
| 52 | +except (ModuleNotFoundError, ImportError): |
| 53 | + try: |
| 54 | + import horovod.tensorflow as hvd |
| 55 | + |
| 56 | + _hvd_imported = hvd |
| 57 | + except (ModuleNotFoundError, ImportError): |
| 58 | + _hvd_imported = None |
| 59 | + |
| 60 | + |
| 61 | +try: |
| 62 | + import smdistributed.dataparallel.torch.distributed as smdataparallel |
| 63 | + |
| 64 | + _smdataparallel_imported = smdataparallel |
| 65 | +except (ModuleNotFoundError, ImportError): |
| 66 | + try: |
| 67 | + import smdistributed.dataparallel.tensorflow as smdataparallel |
| 68 | + |
| 69 | + _smdataparallel_imported = smdataparallel |
| 70 | + except (ModuleNotFoundError, ImportError): |
| 71 | + _smdataparallel_imported = None |
| 72 | + |
| 73 | + |
26 | 74 | logger = get_logger() |
27 | 75 |
|
28 | 76 |
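The hunk above moves each optional distributed-framework import to module scope and caches the module (or None) in a global, so later rank lookups do not retry failed imports on every call. A minimal sketch of that import-caching pattern, using a hypothetical `somedist` package rather than any real smdebug dependency:

    # Sketch only: `somedist` stands in for the backends cached above
    # (smdistributed.modelparallel, torch.distributed, horovod, smdistributed.dataparallel).
    try:
        import somedist as sd

        _sd_imported = sd
    except (ImportError, ModuleNotFoundError):
        _sd_imported = None


    def current_rank():
        # Use the cached handle; never re-attempt the import at call time.
        # `is_initialized()` / `get_rank()` are assumed names for illustration only.
        if _sd_imported is not None and _sd_imported.is_initialized():
            return _sd_imported.get_rank()
        return None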
|
@@ -317,51 +365,34 @@ def get_tb_worker(): |
317 | 365 |
|
318 | 366 |
|
319 | 367 | def get_distributed_worker(): |
320 | | - """Get the rank for horovod or torch distributed. If none of them are being used, |
| 368 | + """ |
| 369 | + Get the rank for torch distributed, smdistributed.modelparallel, smdistributed.dataparallel, or horovod. If none of them are in use, |
321 | 370 | return None""" |
322 | 371 | rank = None |
323 | | - try: |
324 | | - import torch.distributed as dist |
325 | | - except (ImportError, ModuleNotFoundError): |
326 | | - dist = None |
327 | | - rank = None |
328 | | - if dist and hasattr(dist, "is_initialized") and dist.is_initialized(): |
329 | | - rank = dist.get_rank() |
330 | | - else: |
| 372 | + if ( |
| 373 | + _torch_dist_imported |
| 374 | + and hasattr(_torch_dist_imported, "is_initialized") |
| 375 | + and _torch_dist_imported.is_initialized() |
| 376 | + ): |
| 377 | + rank = _torch_dist_imported.get_rank() |
| 378 | + elif _smp_imported and smp.core.initialized: |
| 379 | + rank = smp.rank() |
| 380 | + elif check_smdataparallel_env(): |
| 381 | + # smdistributed.dataparallel should be invoked via `mpirun`. |
| 382 | + # It supports EC2 machines with 8 GPUs per machine. |
| 383 | + assert smdataparallel is not None |
331 | 384 | try: |
332 | | - import horovod.torch as hvd |
333 | | - |
334 | | - if hvd.size(): |
335 | | - rank = hvd.rank() |
336 | | - except (ModuleNotFoundError, ValueError, ImportError): |
| 385 | + if smdataparallel.get_world_size(): |
| 386 | + return smdataparallel.get_rank() |
| 387 | + except ValueError: |
337 | 388 | pass |
338 | | - |
| 389 | + elif _hvd_imported: |
339 | 390 | try: |
340 | | - import horovod.tensorflow as hvd |
341 | | - |
342 | 391 | if hvd.size(): |
343 | 392 | rank = hvd.rank() |
344 | | - except (ModuleNotFoundError, ValueError, ImportError): |
| 393 | + except ValueError: |
345 | 394 | pass |
346 | 395 |
|
347 | | - # smdistributed.dataparallel should be invoked via `mpirun`. |
348 | | - # It supports EC2 machines with 8 GPUs per machine. |
349 | | - if check_smdataparallel_env(): |
350 | | - try: |
351 | | - import smdistributed.dataparallel.torch.distributed as smdataparallel |
352 | | - |
353 | | - if smdataparallel.get_world_size(): |
354 | | - return smdataparallel.get_rank() |
355 | | - except (ModuleNotFoundError, ValueError, ImportError): |
356 | | - pass |
357 | | - |
358 | | - try: |
359 | | - import smdistributed.dataparallel.tensorflow as smdataparallel |
360 | | - |
361 | | - if smdataparallel.size(): |
362 | | - return smdataparallel.rank() |
363 | | - except (ModuleNotFoundError, ValueError, ImportError): |
364 | | - pass |
365 | 396 | return rank |
366 | 397 |
|
367 | 398 |
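For reference, a caller of the refactored `get_distributed_worker()` only needs the returned rank to decide whether the current process should act; `None` means no distributed framework is initialized. The `should_write` guard below is illustrative, not part of smdebug:

    rank = get_distributed_worker()
    # Treat a non-distributed (single-process) run the same as rank 0.
    should_write = rank is None or rank == 0
    if should_write:
        pass  # e.g. emit logs or create writers from exactly one worker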
|
|