From 0b6f9e49218304c033c1e723a13400a38366e1a5 Mon Sep 17 00:00:00 2001
From: Zachary Mueller
Date: Wed, 24 May 2023 15:51:09 -0400
Subject: [PATCH] Fix sagemaker DP/MP (#23681)

* Check for use_sagemaker_dp

* Add a check for is_sagemaker_mp when setting _n_gpu again. Should be last broken thing

* Try explicit check?

* Quality
---
 src/transformers/trainer.py       | 4 +++-
 src/transformers/training_args.py | 8 ++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 72fcd34d7ff..79f19f0a342 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3398,7 +3398,9 @@ def _nested_gather(self, tensors, name=None):
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+            self.args.distributed_state is None and self.local_rank != -1
+        ):
             tensors = distributed_concat(tensors)
         return tensors

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 57aca25712d..36258c1508f 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1629,6 +1629,9 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
             torch.cuda.set_device(device)
+        elif is_sagemaker_dp_enabled():
+            self.distributed_state = PartialState(_use_sagemaker_dp=True)
+            self._n_gpu = 1
         elif self.deepspeed:
             # Need to do similar for Accelerator init
             os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
@@ -1653,8 +1656,9 @@ def _setup_devices(self) -> "torch.device":
         if is_torch_tpu_available():
             device = self.distributed_state.device
             self._n_gpu = 0
-        elif is_sagemaker_dp_enabled():
-            self._n_gpu = 1
+        elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+            # Already set _n_gpu
+            pass
         elif self.distributed_state.distributed_type == DistributedType.NO:
             if self.use_mps_device:
                 if not torch.backends.mps.is_available():
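
A minimal sketch of the gather condition this patch introduces in _nested_gather, written as a
standalone helper for illustration only. The function name should_distributed_concat, its
parameters, and the stand-in distributed_state object are hypothetical and not part of the patch;
it simply mirrors the logic of the new elif branch under those assumptions.

def should_distributed_concat(distributed_state, local_rank):
    """Decide whether tensors should be gathered across processes.

    Mirrors the patched branch: prefer the Accelerate-style state object when one
    exists, otherwise fall back to the legacy local_rank check (e.g. the SageMaker
    MP path, where no state object is created).
    """
    if distributed_state is not None:
        # Any distributed_type other than "NO" means more than one process is running.
        return str(distributed_state.distributed_type) != "NO"
    # No state object available; a local_rank other than -1 indicates a distributed launcher.
    return local_rank != -1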