From 5c524061f0ffdc51b578c0b24733adce03741f57 Mon Sep 17 00:00:00 2001
From: Adi Renduchintala <108822655+arendu@users.noreply.github.com>
Date: Fri, 14 Apr 2023 19:35:57 -0700
Subject: [PATCH] check if grad is none before calling all_reduce (#6428)

---
 .../language_modeling/megatron_gpt_model.py   | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 85ef1c7c2584..bf5799ea53c2 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -456,7 +456,11 @@ def _append_sequence_parallel_module_grads(self, module, grads):
 
         for param in module.parameters():
             sequence_parallel_param = getattr(param, 'sequence_parallel', False)
-            if sequence_parallel_param:
+            # (@adithyare) adapter training now extends MegatronGPTModel
+            # so we have to add this check here to ensure we do not
+            # perform all_reduce when grad is None.
+            # grad can be None when performing PeFT training.
+            if sequence_parallel_param and param.requires_grad:
                 if self.megatron_amp_o2:
                     grad = param.main_grad
                 else:
@@ -504,12 +508,15 @@ def allreduce_first_last_embeddings(self):
                     module = self.model
             if module.share_token_embeddings:
                 word_embeddings_weight = module.word_embeddings_weight()
-                if self.megatron_amp_o2:
-                    # O2 recipe stores a "main" copy of weights and grads
-                    grad = word_embeddings_weight.main_grad
-                else:
-                    grad = word_embeddings_weight.grad
-                torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
+                # (@adithyare) adapter training now extends MegatronGPTModel so we have to add this check here to ensure we do not perform all_reduce when grad is None.
+                # grad can be None when performing PeFT training.
+                if word_embeddings_weight.requires_grad:
+                    if self.megatron_amp_o2:
+                        # O2 recipe stores a "main" copy of weights and grads
+                        grad = word_embeddings_weight.main_grad
+                    else:
+                        grad = word_embeddings_weight.grad
+                    torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
 
     def get_forward_output_and_loss_func(self, validation_step=False):
         def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
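
Note (not part of the patch): the change boils down to skipping the collective for parameters that never receive a gradient, which happens for frozen base-model weights during adapter/PEFT training. Below is a minimal self-contained sketch of that guard pattern, assuming an initialized torch.distributed process group; allreduce_trainable_grads is a hypothetical helper for illustration, not NeMo or Megatron API.

    import torch
    import torch.distributed as dist


    def allreduce_trainable_grads(module: torch.nn.Module, group=None):
        """All-reduce gradients, skipping parameters with no gradient.

        Frozen parameters (e.g. the base model during adapter/PEFT training)
        have requires_grad=False and .grad is None, so calling all_reduce on
        them would fail; the guard below mirrors the check added in the patch.
        """
        for param in module.parameters():
            if param.requires_grad and param.grad is not None:
                dist.all_reduce(param.grad, group=group)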