8 changes: 7 additions & 1 deletion deepspeed/runtime/engine.py
@@ -482,6 +482,9 @@ def steps_per_print(self):
def zero_allgather_partitions(self):
return self._config.zero_config.allgather_partitions

def zero_round_robin_gradients(self):
return self._config.zero_config.round_robin_gradients

def dump_state(self):
return self._config.dump_state

@@ -908,11 +911,13 @@ def _configure_zero_optimizer(self, optimizer):
elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
overlap_comm = self.zero_overlap_comm()
contiguous_gradients = self.zero_contiguous_gradients()
round_robin_gradients = self.zero_round_robin_gradients()

# Overlap, contiguous grads, and round-robin gradients are meaningless in stage 1 and are ignored
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
overlap_comm = False
contiguous_gradients = False
round_robin_gradients = False

if isinstance(self.module, PipelineModule):
if overlap_comm:
@@ -940,7 +945,8 @@ def _configure_zero_optimizer(self, optimizer):
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
ignore_unused_parameters=self.zero_ignore_unused_parameters(),
partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS)
partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS,
round_robin_gradients=round_robin_gradients)
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
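For readers skimming the wiring above: the new accessor plus the stage check mean the configured flag only takes effect under ZeRO stage 2 (gradient partitioning); for stage 1 it is silently forced off, mirroring the existing treatment of overlap_comm and contiguous_gradients. A minimal, standalone sketch of that gating logic (function name below is illustrative, not DeepSpeed API):

```python
# Illustrative sketch only: distills the stage gating added in engine.py above.
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1  # stage 1: partition optimizer states
ZERO_OPTIMIZATION_GRADIENTS = 2         # stage 2: also partition gradients

def effective_round_robin_gradients(zero_stage, configured_value):
    # Round-robin gradient partitioning is only meaningful when gradients are
    # actually partitioned (stage 2); stage 1 ignores the configured value.
    if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
        return False
    return configured_value

print(effective_round_robin_gradients(1, True))  # False -- ignored in stage 1
print(effective_round_robin_gradients(2, True))  # True  -- forwarded to the stage-2 optimizer
```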
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -39,6 +39,7 @@ def __init__(self, param_dict):
self.gather_fp16_weights_on_model_save = None

self.ignore_unused_parameters = None
self.round_robin_gradients = None

if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -184,3 +185,8 @@ def _initialize(self, zero_config_dict):
self.legacy_stage1 = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_LEGACY_STAGE1,
ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT)

self.round_robin_gradients = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)
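For context, `get_scalar_param` is DeepSpeed's helper for reading a single config key with a fallback default; its body is not shown in this diff and is not modified by this PR. A minimal sketch of the lookup the new lines perform, assuming the helper behaves like `dict.get` with a default:

```python
# Illustrative approximation of the new config read; names mirror the
# constants added in deepspeed/runtime/zero/constants.py below.
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS = "round_robin_gradients"
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT = False

def read_round_robin_gradients(zero_config_dict):
    # Assumed equivalent of get_scalar_param(dict, key, default).
    return zero_config_dict.get(ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
                                ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)

print(read_round_robin_gradients({}))                               # False (default)
print(read_round_robin_gradients({"round_robin_gradients": True}))  # True
```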
11 changes: 9 additions & 2 deletions deepspeed/runtime/zero/constants.py
@@ -30,7 +30,8 @@
"sub_group_size" : 1000000000000,
"offload_param": {...},
"offload_optimizer": {...},
"ignore_unused_parameters": [true|false]
"ignore_unused_parameters": [true|false],
"round_robin_gradients": [true|false]
}
}
'''
@@ -124,6 +125,10 @@
ZERO_OPTIMIZATION_LEGACY_STAGE1 = "legacy_stage1"
ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT = False

# Stage 2 - partition gradients in a round-robin fashion to load-balance gradient reduction and offload copying
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS = 'round_robin_gradients'
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT = False

#yapf: disable
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
@@ -161,5 +166,7 @@
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT,
ZERO_OPTIMIZATION_LEGACY_STAGE1:
ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT
ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT,
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS:
ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT
}
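Combined with the config-format docstring above, enabling the feature from a user config only requires setting the new key under `zero_optimization`. A hedged sketch of such a config fragment (every value other than `round_robin_gradients` is an illustrative placeholder, not part of this diff):

```python
# Sketch of a ZeRO stage-2 DeepSpeed config enabling round-robin gradient
# partitioning; the flag defaults to False when omitted.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,   # placeholder value
    "fp16": {"enabled": True},             # placeholder value
    "zero_optimization": {
        "stage": 2,
        "round_robin_gradients": True,     # new key introduced by this PR
    },
}
```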
17 changes: 12 additions & 5 deletions deepspeed/runtime/zero/stage2.py
@@ -99,7 +99,8 @@ def __init__(self,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
ignore_unused_parameters=True,
partition_grads=True):
partition_grads=True,
round_robin_gradients=False):

if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -159,6 +160,7 @@ def __init__(self,
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.ignore_unused_parameters = ignore_unused_parameters
self.round_robin_gradients = round_robin_gradients

self.extra_large_param_to_reduce = None

@@ -232,10 +234,15 @@ def __init__(self,
# This ensures that gradients are reduced in a fashion such that ownership round-robins among the ranks.
# For example, rather than 3 consecutively reduced gradients (g_n+2, g_n+1, g_n) all belonging
# to the same rank, they will belong to 3 different ranks (r_m+2, r_m+1, r_m).
round_robin_tensors, round_robin_indices = self._round_robin_reorder(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)
)
if self.round_robin_gradients:
round_robin_tensors, round_robin_indices = self._round_robin_reorder(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)
)
else:
round_robin_tensors = self.fp16_groups[i]
round_robin_indices = list(range(len(self.fp16_groups[i])))

self.round_robin_fp16_groups.append(round_robin_tensors)
self.round_robin_fp16_indices.append(round_robin_indices)
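The comment in this hunk describes what `_round_robin_reorder` is meant to achieve: interleave parameter ownership across data-parallel ranks so that consecutively reduced gradients land on different ranks. That method's body is not part of this diff; the following standalone sketch shows one way such a reordering could work (the function name, return convention, and tensor-count-based bucketing are simplifying assumptions for illustration; the real partitioning also accounts for element counts):

```python
def round_robin_reorder_sketch(tensor_list, num_partitions):
    """Illustrative reordering: bucket tensors by (original index % num_partitions)
    so that, after contiguous partitioning of the reordered list, consecutive
    tensors from the original list are owned by different ranks.
    Returns the reordered list and a map from original index to new index."""
    buckets = {p: [] for p in range(num_partitions)}
    for i, tensor in enumerate(tensor_list):
        buckets[i % num_partitions].append((i, tensor))

    reordered, new_index_of = [], {}
    for p in range(num_partitions):
        for original_index, tensor in buckets[p]:
            new_index_of[original_index] = len(reordered)
            reordered.append(tensor)
    return reordered, new_index_of

# Tiny usage example with 6 placeholder "parameters" and 3 ranks.
params = [f"p{i}" for i in range(6)]
reordered, idx_map = round_robin_reorder_sketch(params, num_partitions=3)
print(reordered)  # ['p0', 'p3', 'p1', 'p4', 'p2', 'p5']
print(idx_map)    # {0: 0, 3: 1, 1: 2, 4: 3, 2: 4, 5: 5}
```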
