diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 2f4ed9e20aff..2c48d0ed9bfb 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -482,6 +482,9 @@ def steps_per_print(self):
     def zero_allgather_partitions(self):
         return self._config.zero_config.allgather_partitions
 
+    def zero_round_robin_gradients(self):
+        return self._config.zero_config.round_robin_gradients
+
     def dump_state(self):
         return self._config.dump_state
 
@@ -908,11 +911,13 @@ def _configure_zero_optimizer(self, optimizer):
         elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
             overlap_comm = self.zero_overlap_comm()
             contiguous_gradients = self.zero_contiguous_gradients()
+            round_robin_gradients = self.zero_round_robin_gradients()
 
             # Overlap and contiguous grads are meaningless in stage 1 and are ignored
             if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
                 overlap_comm = False
                 contiguous_gradients = False
+                round_robin_gradients = False
 
             if isinstance(self.module, PipelineModule):
                 if overlap_comm:
@@ -940,7 +945,8 @@ def _configure_zero_optimizer(self, optimizer):
                 gradient_predivide_factor=self.gradient_predivide_factor(),
                 gradient_accumulation_steps=self.gradient_accumulation_steps(),
                 ignore_unused_parameters=self.zero_ignore_unused_parameters(),
-                partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS)
+                partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS,
+                round_robin_gradients=round_robin_gradients)
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 377ad94549a7..fe81fceebd33 100755
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -39,6 +39,7 @@ def __init__(self, param_dict):
         self.gather_fp16_weights_on_model_save = None
 
         self.ignore_unused_parameters = None
+        self.round_robin_gradients = None
 
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -184,3 +185,8 @@ def _initialize(self, zero_config_dict):
         self.legacy_stage1 = get_scalar_param(zero_config_dict,
                                               ZERO_OPTIMIZATION_LEGACY_STAGE1,
                                               ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT)
+
+        self.round_robin_gradients = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
+            ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)
diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py
index a003dc5611f2..6c1f249540cb 100755
--- a/deepspeed/runtime/zero/constants.py
+++ b/deepspeed/runtime/zero/constants.py
@@ -30,7 +30,8 @@
     "sub_group_size" : 1000000000000,
     "offload_param": {...},
     "offload_optimizer": {...},
-    "ignore_unused_parameters": [true|false]
+    "ignore_unused_parameters": [true|false],
+    "round_robin_gradients": [true|false]
     }
 }
 '''
@@ -124,6 +125,10 @@
 ZERO_OPTIMIZATION_LEGACY_STAGE1 = "legacy_stage1"
 ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT = False
 
+# Stage 2 - partition gradients in a round robin fashion to load-balance reduction and offload copying
+ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS = 'round_robin_gradients'
+ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT = False
+
 #yapf: disable
 ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_STAGE:
@@ -161,5 +166,7 @@
     ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
     ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT,
     ZERO_OPTIMIZATION_LEGACY_STAGE1:
-    ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT
+    ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT,
+    ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS:
+    ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT
 }
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 5ea19d7f738a..064d59629d87 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -99,7 +99,8 @@ def __init__(self,
                  gradient_predivide_factor=1.0,
                  gradient_accumulation_steps=1,
                  ignore_unused_parameters=True,
-                 partition_grads=True):
+                 partition_grads=True,
+                 round_robin_gradients=False):
 
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -159,6 +160,7 @@ def __init__(self,
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.micro_step_id = 0
         self.ignore_unused_parameters = ignore_unused_parameters
+        self.round_robin_gradients = round_robin_gradients
 
         self.extra_large_param_to_reduce = None
 
@@ -232,10 +234,15 @@ def __init__(self,
             # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
             # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
             # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
-            round_robin_tensors, round_robin_indices = self._round_robin_reorder(
-                self.fp16_groups[i],
-                dist.get_world_size(group=self.dp_process_group)
-            )
+            if self.round_robin_gradients:
+                round_robin_tensors, round_robin_indices = self._round_robin_reorder(
+                    self.fp16_groups[i],
+                    dist.get_world_size(group=self.dp_process_group)
+                )
+            else:
+                round_robin_tensors = self.fp16_groups[i]
+                round_robin_indices = list(range(len(self.fp16_groups[i])))
+
             self.round_robin_fp16_groups.append(round_robin_tensors)
             self.round_robin_fp16_indices.append(round_robin_indices)
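For context, the new knob is read from the user's DeepSpeed config under `zero_optimization`. Below is a minimal sketch of such a config (not part of this diff): only `round_robin_gradients` is the new field, the other keys are existing ZeRO stage 2 options already listed in constants.py, and the sizes are placeholders.

```python
# Hedged example config (not part of the diff). Only "round_robin_gradients"
# is new; the other keys are existing ZeRO stage 2 options, and the batch
# size is a placeholder.
ds_config = {
    "train_batch_size": 32,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "round_robin_gradients": True,
    },
}
# Passed to deepspeed.initialize(...) (the config keyword name varies across
# DeepSpeed versions), this flows through DeepSpeedZeroConfig via
# get_scalar_param and is forwarded by the engine as
# round_robin_gradients=... to FP16_DeepSpeedZeroOptimizer.
```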
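The stage2.py hunk only gates the existing `_round_robin_reorder` call behind the new flag; that helper's body is not part of this diff. As a rough illustration of the interleaving described in the comments above the call (ownership cycling across data-parallel ranks), here is a self-contained sketch. The function name, return convention, and index map are assumptions for illustration, not DeepSpeed's actual implementation.

```python
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def round_robin_reorder_sketch(tensors: Sequence[T],
                               num_partitions: int) -> Tuple[List[T], Dict[int, int]]:
    """Deal the tensors out to `num_partitions` buckets in turn, then flatten
    the buckets back into one list. Consecutive entries in the new order belong
    to different partitions (ranks), which is the load-balancing property the
    stage2.py comment describes. Returns the reordered list plus a map from
    each original index to its new position (our reading of what the diff
    stores in round_robin_fp16_indices)."""
    buckets: List[List[Tuple[int, T]]] = [[] for _ in range(num_partitions)]
    for idx, tensor in enumerate(tensors):
        buckets[idx % num_partitions].append((idx, tensor))

    reordered: List[T] = []
    new_position: Dict[int, int] = {}
    for bucket in buckets:
        for original_idx, tensor in bucket:
            new_position[original_idx] = len(reordered)
            reordered.append(tensor)
    return reordered, new_position


if __name__ == "__main__":
    # Seven gradients spread across three ranks: no two consecutive entries
    # in the reordered list are owned by the same rank.
    grads = [f"g{i}" for i in range(7)]
    reordered, index_map = round_robin_reorder_sketch(grads, num_partitions=3)
    print(reordered)   # ['g0', 'g3', 'g6', 'g1', 'g4', 'g2', 'g5']
    print(index_map)   # {0: 0, 3: 1, 6: 2, 1: 3, 4: 4, 2: 5, 5: 6}
```

When the flag is off, the else-branch in the diff keeps the original ordering and an identity index list, so existing behavior is unchanged by default.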