Skip to content

Commit

Permalink
check if grad is none before calling all_reduce (#6428)
Browse files Browse the repository at this point in the history
  • Loading branch information
arendu authored Apr 15, 2023
1 parent ae55b52 commit 5c52406
Showing 1 changed file with 14 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,11 @@ def _append_sequence_parallel_module_grads(self, module, grads):

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False)
if sequence_parallel_param:
# (@adithyare) adapter training now extends MegatronGPTModel
# so we have to add this check here to ensure we do not
# perform all_reduce when grad is None.
# grad can be None when performing PeFT training.
if sequence_parallel_param and param.requires_grad:
if self.megatron_amp_o2:
grad = param.main_grad
else:
Expand Down Expand Up @@ -504,12 +508,15 @@ def allreduce_first_last_embeddings(self):
module = self.model
if module.share_token_embeddings:
word_embeddings_weight = module.word_embeddings_weight()
if self.megatron_amp_o2:
# O2 recipe stores a "main" copy of weights and grads
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
# (@adithyare) adapter training now extends MegatronGPTModel so we have to add this check here to ensure we do not perform all_reduce when grad is None.
# grad can be None when performing PeFT training.
if word_embeddings_weight.requires_grad:
if self.megatron_amp_o2:
# O2 recipe stores a "main" copy of weights and grads
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

def get_forward_output_and_loss_func(self, validation_step=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
Expand Down

0 comments on commit 5c52406

Please sign in to comment.