From e5ecdf5463ab5c5205d903e6a7bddb7169a4a165 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Mon, 26 Jul 2021 17:13:26 -0700
Subject: [PATCH 1/5] ignore overlap/contiguous_gradients if using zero 1 (#1246)

Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/engine.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 013924def623..579e4aaf424e 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -592,7 +592,7 @@ def _configure_with_arguments(self, args, mpu):
         if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
             ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
             local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)
-            assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({mpi_local_rank}), " \
+            assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
                 "not sure how to proceed as we're seeing conficting local rank info."
             os.environ['LOCAL_RANK'] = local_rank
 
@@ -926,6 +926,13 @@ def _configure_zero_optimizer(self, optimizer):
                 gradient_predivide=self.gradient_predivide)
         elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
             overlap_comm = self.zero_overlap_comm()
+            contiguous_gradients = self.zero_contiguous_gradients()
+
+            # Overlap and contiguous grads are meaningless in stage 1 and are ignored
+            if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
+                overlap_comm = False
+                contiguous_gradients = False
+
             if isinstance(self.module, PipelineModule):
                 if overlap_comm:
                     logger.warning(
@@ -940,7 +947,7 @@ def _configure_zero_optimizer(self, optimizer):
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
                 clip_grad=self.gradient_clipping(),
-                contiguous_gradients=self.zero_contiguous_gradients(),
+                contiguous_gradients=contiguous_gradients,
                 reduce_bucket_size=self.zero_reduce_bucket_size(),
                 allgather_bucket_size=self.zero_allgather_bucket_size(),
                 dp_process_group=self.data_parallel_group,

From 5bb09f87939adc85bdc1a54f9b3aff09757fb9d4 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 28 Jul 2021 11:03:06 -0700
Subject: [PATCH 2/5] Make round robin gradient partitioning configurable (default False) (#1256)

---
 deepspeed/runtime/engine.py         |  8 +++++++-
 deepspeed/runtime/zero/config.py    |  6 ++++++
 deepspeed/runtime/zero/constants.py | 11 +++++++++--
 deepspeed/runtime/zero/stage2.py    | 17 ++++++++++++-----
 4 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 579e4aaf424e..0dc447c54765 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -501,6 +501,9 @@ def steps_per_print(self):
     def zero_allgather_partitions(self):
         return self._config.zero_config.allgather_partitions
 
+    def zero_round_robin_gradients(self):
+        return self._config.zero_config.round_robin_gradients
+
     def dump_state(self):
         return self._config.dump_state
 
@@ -927,11 +930,13 @@ def _configure_zero_optimizer(self, optimizer):
         elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
             overlap_comm = self.zero_overlap_comm()
             contiguous_gradients = self.zero_contiguous_gradients()
+            round_robin_gradients = self.zero_round_robin_gradients()
 
             # Overlap and contiguous grads are meaningless in stage 1 and are ignored
             if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
                 overlap_comm = False
                 contiguous_gradients = False
+                round_robin_gradients = False
 
             if isinstance(self.module, PipelineModule):
                 if overlap_comm:
@@ -959,7 +964,8 @@ def _configure_zero_optimizer(self, optimizer):
                 gradient_predivide_factor=self.gradient_predivide_factor(),
                 gradient_accumulation_steps=self.gradient_accumulation_steps(),
                 ignore_unused_parameters=self.zero_ignore_unused_parameters(),
-                partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS)
+                partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS,
+                round_robin_gradients=round_robin_gradients)
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 377ad94549a7..fe81fceebd33 100755
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -39,6 +39,7 @@ def __init__(self, param_dict):
 
         self.gather_fp16_weights_on_model_save = None
         self.ignore_unused_parameters = None
+        self.round_robin_gradients = None
 
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -184,3 +185,8 @@ def _initialize(self, zero_config_dict):
         self.legacy_stage1 = get_scalar_param(zero_config_dict,
                                               ZERO_OPTIMIZATION_LEGACY_STAGE1,
                                               ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT)
+
+        self.round_robin_gradients = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
+            ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py
index 7beebe00e717..e0a26d53609f 100755
--- a/deepspeed/runtime/zero/constants.py
+++ b/deepspeed/runtime/zero/constants.py
@@ -30,7 +30,8 @@
     "sub_group_size" : 1000000000000,
     "offload_param": {...},
     "offload_optimizer": {...},
-    "ignore_unused_parameters": [true|false]
+    "ignore_unused_parameters": [true|false],
+    "round_robin_gradients": [true|false]
     }
 }
 '''
@@ -124,6 +125,10 @@
 ZERO_OPTIMIZATION_LEGACY_STAGE1 = "legacy_stage1"
 ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT = False
 
+# Stage 2 - partition gradients in a round robin fashsion to load-balance reduction and offload copying
+ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS = 'round_robin_gradients'
+ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT = False
+
 #yapf: disable
 ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_STAGE:
@@ -161,5 +166,7 @@
     ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
     ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT,
     ZERO_OPTIMIZATION_LEGACY_STAGE1:
-    ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT
+    ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT,
+    ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS:
+    ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT
 }
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 64955ff400dc..8e4394d14f79 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -99,7 +99,8 @@ def __init__(self,
                  gradient_predivide_factor=1.0,
                  gradient_accumulation_steps=1,
                  ignore_unused_parameters=True,
-                 partition_grads=True):
+                 partition_grads=True,
+                 round_robin_gradients=False):
 
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -159,6 +160,7 @@ def __init__(self,
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.micro_step_id = 0
         self.ignore_unused_parameters = ignore_unused_parameters
+        self.round_robin_gradients = round_robin_gradients
 
         self.extra_large_param_to_reduce = None
 
@@ -232,10 +234,15 @@ def __init__(self,
             # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
             # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
             # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
-            round_robin_tensors, round_robin_indices = self._round_robin_reorder(
-                self.fp16_groups[i],
-                dist.get_world_size(group=self.dp_process_group)
-            )
+            if self.round_robin_gradients:
+                round_robin_tensors, round_robin_indices = self._round_robin_reorder(
+                    self.fp16_groups[i],
+                    dist.get_world_size(group=self.dp_process_group)
+                )
+            else:
+                round_robin_tensors = self.fp16_groups[i]
+                round_robin_indices = list(range(len(self.fp16_groups[i])))
+
             self.round_robin_fp16_groups.append(round_robin_tensors)
             self.round_robin_fp16_indices.append(round_robin_indices)
 

From d370f535ce9b289cc9cb4442cd938ca83cb04221 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 29 Jul 2021 22:51:53 +0000
Subject: [PATCH 3/5] pass GAS boundary state from PP -> ZeRO

---
 deepspeed/runtime/engine.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0dc447c54765..57b9c12b6f9c 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1176,6 +1176,9 @@ def forward(self, *inputs, **kwargs):
         return loss
 
     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
+        # Pass (PP) gas boundary flag to optimizer (required for zero)
+        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
+
         # ZeRO stage 2 communicates during non gradient accumulation boundaries as well
         if self.zero_optimization_partition_gradients():
             self.optimizer.overlapping_partition_gradients_reduce_epilogue()

From 0067c88ee58c7600c0be6ea435c95ad86beb12d8 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 28 Jul 2021 15:11:42 -0700
Subject: [PATCH 4/5] Use correct default for round robin gradients (#1258)

* Make round robin gradient partitioning configurable (default False)

* Use the correct default

* Log config setting
---
 deepspeed/runtime/zero/config.py | 2 +-
 deepspeed/runtime/zero/stage2.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index fe81fceebd33..a48dd4e620b4 100755
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -189,4 +189,4 @@ def _initialize(self, zero_config_dict):
         self.round_robin_gradients = get_scalar_param(
             zero_config_dict,
             ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
-            ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
+            ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 8e4394d14f79..a7597773641d 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -106,6 +106,7 @@ def __init__(self,
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
             logger.info(f"Allgather bucket size {allgather_bucket_size}")
             logger.info(f"CPU Offload: {cpu_offload}")
+            logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
         # The fused optimizer does all the work. We need this layer for two reason:
         # 1. maintain same user API from apex.fp16_utils
         # 2. keep common stuff here in case we need to add new fused optimizer later
From 624303f2820719abaa6710f6abc2442db74a49b5 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 29 Jul 2021 23:01:36 +0000
Subject: [PATCH 5/5] formatting

---
 deepspeed/runtime/engine.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 57b9c12b6f9c..5791b3fab85d 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1177,7 +1177,8 @@ def forward(self, *inputs, **kwargs):
 
     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
         # Pass (PP) gas boundary flag to optimizer (required for zero)
-        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
+        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+        )
 
         # ZeRO stage 2 communicates during non gradient accumulation boundaries as well
         if self.zero_optimization_partition_gradients():
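Note: below is an illustrative ds_config sketch, not part of the patches, showing how a user
could enable the "round_robin_gradients" flag introduced in PATCH 2/5 and given its correct
default in PATCH 4/5. The flag only has an effect under ZeRO stage 2 in this series, since
PATCH 1/5 forces the stage-2 knobs off when stage 1 is selected. The batch-size and fp16 keys
are assumed surrounding configuration and are not introduced by these patches.

    {
      "train_micro_batch_size_per_gpu": 8,
      "gradient_accumulation_steps": 4,
      "fp16": { "enabled": true },
      "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "round_robin_gradients": true
      }
    }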