diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b09cf2b11ca5..2f4ed9e20aff 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -573,7 +573,7 @@ def _configure_with_arguments(self, args, mpu):
         if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
             ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
             local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)
-            assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({mpi_local_rank}), " \
+            assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
                 "not sure how to proceed as we're seeing conficting local rank info."
             os.environ['LOCAL_RANK'] = local_rank
@@ -907,6 +907,13 @@ def _configure_zero_optimizer(self, optimizer):
                 gradient_predivide=self.gradient_predivide)
         elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
             overlap_comm = self.zero_overlap_comm()
+            contiguous_gradients = self.zero_contiguous_gradients()
+
+            # Overlap and contiguous grads are meaningless in stage 1 and are ignored
+            if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
+                overlap_comm = False
+                contiguous_gradients = False
+
             if isinstance(self.module, PipelineModule):
                 if overlap_comm:
                     logger.warning(
@@ -921,7 +928,7 @@ def _configure_zero_optimizer(self, optimizer):
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
                 clip_grad=self.gradient_clipping(),
-                contiguous_gradients=self.zero_contiguous_gradients(),
+                contiguous_gradients=contiguous_gradients,
                 reduce_bucket_size=self.zero_reduce_bucket_size(),
                 allgather_bucket_size=self.zero_allgather_bucket_size(),
                 dp_process_group=self.data_parallel_group,
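For readers skimming the second hunk, below is a minimal, self-contained sketch of the stage-gating rule it introduces. The constant values mirror DeepSpeed's ZeRO stage numbering (stage 1 partitions optimizer states, stage 2 additionally partitions gradients); effective_zero_flags is a hypothetical helper used only for illustration, not a DeepSpeed function.

# Standalone sketch (not DeepSpeed's API) of the gating added in the second hunk.
# Stage constants follow deepspeed's ZeRO numbering.
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1  # stage 1: partition optimizer states
ZERO_OPTIMIZATION_GRADIENTS = 2         # stage 2: also partition gradients

def effective_zero_flags(zero_stage, overlap_comm, contiguous_gradients):
    # Per the patch's comment, overlap_comm and contiguous_gradients are
    # meaningless under stage 1, so both are forced off there instead of
    # being passed through to the stage 1/2 optimizer.
    if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
        return False, False
    return overlap_comm, contiguous_gradients

# Stage 1 silently ignores both flags; stage 2 honors the user's settings.
assert effective_zero_flags(1, True, True) == (False, False)
assert effective_zero_flags(2, True, True) == (True, True)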