Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def _configure_with_arguments(self, args, mpu):
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({mpi_local_rank}), " \
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
"not sure how to proceed as we're seeing conficting local rank info."
os.environ['LOCAL_RANK'] = local_rank

Expand Down Expand Up @@ -907,6 +907,13 @@ def _configure_zero_optimizer(self, optimizer):
gradient_predivide=self.gradient_predivide)
elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
overlap_comm = self.zero_overlap_comm()
contiguous_gradients = self.zero_contiguous_gradients()

# Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
overlap_comm = False
contiguous_gradients = False

if isinstance(self.module, PipelineModule):
if overlap_comm:
logger.warning(
Expand All @@ -921,7 +928,7 @@ def _configure_zero_optimizer(self, optimizer):
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
contiguous_gradients=contiguous_gradients,
reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
Expand Down