From 7798d5aa97bc00af2b930430dcf0e2cc92a21ec1 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Thu, 28 Nov 2024 08:24:12 +0200
Subject: [PATCH 1/4] Zero2: avoid graph breaks in torch.compile by using
 param_idx

Inside reduce_independent_p_g_buckets_and_remove_grads and
reduce_ipg_grads, which are executed during the BWD hook in ZeRO-2, the
model param is stored inside params_in_ipg_bucket. torch.compile has a
hard time tracing parameters. By using the param's static index inside
the group, the same logic can be maintained with less complexity.
---
 deepspeed/runtime/zero/stage_1_and_2.py | 8 +++++---
 tests/unit/moe/test_moe.py              | 3 ++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 7ac89a233808..3b1596badf31 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
             for param in param_group['params']:
                 if param.requires_grad:
                     param.grad_accum = None
+                    param.param_idx_in_group = len(trainable_parameters)
                     trainable_parameters.append(param)
             self.bit16_groups.append(trainable_parameters)
@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

         self.grads_in_ipg_bucket.append(grad_reduc)
-        self.params_in_ipg_bucket.append((i, param, param_id))
+        self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

         #make sure the average tensor function knows how to average the gradients
         if is_moe_param(param):
@@ -1067,7 +1068,7 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
-            for i, param, param_id in self.params_in_ipg_bucket:
+            for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:

                 process_group = self.dp_process_group
@@ -1383,7 +1384,8 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()

         with get_accelerator().stream(stream):
-            for _, param, param_id in self.params_in_ipg_bucket:
+            for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[group_idx][param_idx_in_group]
                 assert self.params_already_reduced[param_id] == False, \
                     f"The parameter {param_id} has already been reduced. \
                     Gradient computed twice for this partition. \
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 9ee546437f6c..c67a907c6785 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
         process_group = optimizer.dp_process_group
         curr_size = 0
         pg_offsets = []
-        for i, param, param_id in optimizer.params_in_ipg_bucket:
+        for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+            param = optimizer.bit16_groups[i][param_idx]
             process_group = optimizer.dp_process_group
             if optimizer.ipg_bucket_has_moe_params:
                 process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(

From 7d1c883962685b8de68a2340fd4d1ccf1c39850e Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Wed, 11 Dec 2024 11:56:45 +0200
Subject: [PATCH 2/4] Update deepspeed/runtime/zero/stage_1_and_2.py

Assign param according to the param idx.

Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/zero/stage_1_and_2.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 3b1596badf31..457fdd291bc6 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1069,6 +1069,7 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
             for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+            param = self.bit16_groups[group_idx][param_idx_in_group]

                 process_group = self.dp_process_group

From f7e8d53df1a97a9cfeae3ed417b2209bec7be656 Mon Sep 17 00:00:00 2001
From: Logan Adams
Date: Mon, 16 Dec 2024 14:52:09 -0800
Subject: [PATCH 3/4] Formatting, fix indent

---
 deepspeed/runtime/zero/stage_1_and_2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 457fdd291bc6..c4163d6a850f 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1069,7 +1069,7 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
             for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
-            param = self.bit16_groups[group_idx][param_idx_in_group]
+                param = self.bit16_groups[group_idx][param_idx_in_group]

                 process_group = self.dp_process_group

From 7606795a2c49c28c275441edc5a81ead2325e4a5 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Tue, 17 Dec 2024 09:51:46 +0200
Subject: [PATCH 4/4] Update stage_1_and_2.py

Fix index value to 'i' instead of 'group_idx'.
---
 deepspeed/runtime/zero/stage_1_and_2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index c4163d6a850f..ecb2a527f870 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1069,7 +1069,7 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
             for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
-                param = self.bit16_groups[group_idx][param_idx_in_group]
+                param = self.bit16_groups[i][param_idx_in_group]

                 process_group = self.dp_process_group
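
Taken together, the four patches replace the (group_index, param, param_id)
tuples in params_in_ipg_bucket with (group_index, param_idx_in_group,
param_id), and recover the parameter via self.bit16_groups[i][param_idx_in_group]
only where it is actually needed. The standalone sketch below illustrates the
pattern: the bucket then carries only plain ints, which torch.compile handles
more gracefully than nn.Parameter objects passed through Python containers.
This is a minimal sketch, not DeepSpeed's implementation; the
TinyBucketedReducer name, its methods, and the placeholder reduction are
illustrative assumptions.

    import torch

    class TinyBucketedReducer:
        """Minimal sketch (not DeepSpeed): buckets hold static indices, not params."""

        def __init__(self, param_groups):
            # Mirrors bit16_groups: a list of lists of trainable parameters.
            self.bit16_groups = []
            for group in param_groups:
                trainable = []
                for param in group:
                    if param.requires_grad:
                        # Tag each param once with its static position in the
                        # group; hooks can then pass plain ints around.
                        param.param_idx_in_group = len(trainable)
                        trainable.append(param)
                self.bit16_groups.append(trainable)
            # Holds (group_idx, param_idx) tuples of plain ints only.
            self.params_in_ipg_bucket = []

        def on_grad_ready(self, group_idx, param):
            # Called from a grad hook: record indices, not the tensor itself.
            self.params_in_ipg_bucket.append((group_idx, param.param_idx_in_group))

        def reduce_bucket(self):
            for group_idx, param_idx in self.params_in_ipg_bucket:
                # Recover the parameter by indexing the stable group structure,
                # as reduce_ipg_grads does after the patch.
                param = self.bit16_groups[group_idx][param_idx]
                if param.grad is not None:
                    param.grad.mul_(1.0)  # stand-in for the real average/all-reduce
            self.params_in_ipg_bucket.clear()

    # Usage: one tiny parameter group, then simulate a grad-ready hook firing.
    if __name__ == "__main__":
        g0 = [torch.nn.Parameter(torch.randn(4)), torch.nn.Parameter(torch.randn(4))]
        reducer = TinyBucketedReducer([g0])
        g0[1].grad = torch.ones(4)
        reducer.on_grad_ready(0, g0[1])
        reducer.reduce_bucket()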