diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index e0cf8c1171..285c0f371b 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -94,6 +94,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
             'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
+    if backend == 'smddp':
+        try:
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                'Please use an Amazon SageMaker DLC to access smdistributed: '
+                'https://github.com/aws/deep-learning-containers/blob/master'
+                '/available_images.md#sagemaker-framework-containers'
+                '-sm-support-only') from e
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     if 'MASTER_PORT' not in os.environ:
@@ -433,6 +442,8 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     elif backend == 'cncl':
         import torch_mlu  # noqa: F401
         return torch.device('mlu', torch.mlu.current_device())
+    elif backend == 'smddp':
+        return torch.device('cuda', torch.cuda.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
        return torch.device('cpu')
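
For context, here is a minimal usage sketch (not part of the patch) of how the new backend would be selected from training code. It assumes an Amazon SageMaker DLC where ``smdistributed`` and the ``OMPI_COMM_WORLD_*`` environment variables are already present; ``init_dist`` and ``get_comm_device`` are MMEngine's existing entry points defined in the same ``mmengine/dist/utils.py`` file, and the ``'smddp'`` string simply flows into the two patched functions above.

```python
# Hypothetical sketch, assuming a SageMaker DLC with smdistributed installed
# and the job launched through MPI (which sets OMPI_COMM_WORLD_LOCAL_RANK).
from mmengine.dist import get_comm_device, init_dist

# The 'mpi' launcher routes into the patched ``_init_dist_mpi``, which imports
# smdistributed to register the backend and then forwards 'smddp' to
# ``torch.distributed.init_process_group``.
init_dist(launcher='mpi', backend='smddp')

# With the second hunk applied, collective communication targets the current
# CUDA device instead of falling back to CPU.
assert get_comm_device().type == 'cuda'
```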