diff --git a/Jenkinsfile b/Jenkinsfile
index 663e83737026..0fc492961c61 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -4291,6 +4291,42 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         sh "rm -rf /home/TestData/nlp/lora_tuning_tp2"
       }
     }
+    stage('L2: Megatron GPT PEFT Lora TP=2 SP') {
+      when {
+        anyOf {
+          branch 'main'
+          changeRequest target: 'main'
+        }
+      }
+      failFast true
+      steps {
+        sh "rm -rf /home/TestData/nlp/lora_tuning_tp2_sp"
+        sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
+        trainer.devices=2 \
+        trainer.log_every_n_steps=1 \
+        trainer.max_epochs=9999 \
+        trainer.max_steps=3 \
+        trainer.val_check_interval=3 \
+        ++trainer.limit_val_batches=2 \
+        trainer.precision=16 \
+        exp_manager.exp_dir=/home/TestData/nlp/lora_tuning_tp2_sp \
+        model.pipeline_model_parallel_size=1 \
+        model.tensor_model_parallel_size=2 \
+        model.sequence_parallel=true \
+        model.restore_from_path=/home/TestData/nlp/megatron_gpt/TP2/megatron_gpt_tp2.nemo \
+        model.peft.peft_scheme='lora' \
+        model.answer_only_loss=True \
+        model.micro_batch_size=1 \
+        model.global_batch_size=1 \
+        model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
+        model.data.train_ds.concat_sampling_probabilities=[1.0] \
+        model.data.train_ds.num_workers=0 \
+        model.data.validation_ds.num_workers=0 \
+        model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
+        model.data.validation_ds.names=[quarel]"
+        sh "rm -rf /home/TestData/nlp/lora_tuning_tp2_sp"
+      }
+    }
     stage('L2: Megatron GPT Eval') {
       when {
         anyOf {
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 87dfc37b8b86..1da9f5df1e9a 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -755,7 +755,7 @@ def _append_sequence_parallel_module_grads(self, module, grads):
             # (@adithyare) adapter training now extends MegatronGPTModel
             # so we have to add this check here to ensure we do not
             # perform all_reduce when grad is None.
-            # grad can be None when performing PeFT training.
+            # grad can be None when performing PEFT training.
             if sequence_parallel_param and param.requires_grad:
                 if self.megatron_amp_O2:
                     grad = param.main_grad
@@ -775,7 +775,9 @@ def allreduce_sequence_parallel_gradients(self):
                 self._append_sequence_parallel_module_grads(module, grads)
         else:
             self._append_sequence_parallel_module_grads(self.model, grads)
-
+        if not grads:
+            # grads may be empty for PEFT training
+            return
         coalesced = torch._utils._flatten_dense_tensors(grads)
         torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
         for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index d57d40b5c581..ac85ea7a1d2e 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -43,6 +43,10 @@
 try:
     from megatron.core import ModelParallelConfig
     from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+    from megatron.core.tensor_parallel.mappings import (
+        gather_from_sequence_parallel_region,
+        scatter_to_sequence_parallel_region,
+    )
 
     HAVE_MEGATRON_CORE = True
 
@@ -147,11 +151,13 @@ def __init__(
         self.activation = activation_registry[activation]()
         self.norm_position = norm_position
         self.dim = dim
-
+        self.input_is_parallel = input_is_parallel
         # megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
         # in case this arg is not provided, use the dummy default config.
         if model_parallel_config is None:
             model_parallel_config = ModelParallelConfig()
+        self._sequence_parallel = model_parallel_config.sequence_parallel
+        model_parallel_config.sequence_parallel = False  # SP is irrelevant for the LoRA linear layers
 
         if input_is_parallel:
             self.linear_in = RowParallelLinear(
@@ -219,6 +225,9 @@ def __init__(
         # Setup adapter strategy
         self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())
 
+        # revert the config change in case it is read elsewhere
+        model_parallel_config.sequence_parallel = self._sequence_parallel
+
     def _get_init_fn(self, init_method: str):
         if init_method == 'xavier':
             init_fn = init.xavier_normal_
@@ -240,10 +249,24 @@ def forward(self, x):
 
         if self.norm_position == 'pre':
             x = self.layer_norm(x)
 
+        if self._sequence_parallel and not self.input_is_parallel:
+            # for attention_qkv and linear_fc1:
+            # the layernorm before LoRA is affected by sequence parallelism,
+            # so the sequence dim needs to be gathered right before the LoRA linear layers.
+            # this function also handles the backward pass correctly.
+            x = gather_from_sequence_parallel_region(x)
         x, _ = self.linear_in(x)  # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term.
         x = self.activation(x)
         x, _ = self.linear_out(x)
+
+        if self._sequence_parallel and self.input_is_parallel:
+            # for attention_dense and linear_fc2:
+            # the layernorm after LoRA is affected by sequence parallelism,
+            # so the sequence dim needs to be scattered right after the LoRA linear layers.
+            # this function also handles the backward pass correctly.
+            x = scatter_to_sequence_parallel_region(x)
+
         if self.norm_position == 'post':
             x = self.layer_norm(x)
 
diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py
index 1d365723ebda..815ad4d9e952 100644
--- a/nemo/collections/nlp/parts/peft_config.py
+++ b/nemo/collections/nlp/parts/peft_config.py
@@ -16,6 +16,8 @@
 
 from omegaconf import DictConfig
 
+from nemo.utils import logging
+
 try:
     from nemo.collections.nlp.modules.common.megatron.adapters.mcore_mixins import (
         MCoreGPTEmbeddingMixin,
@@ -148,6 +150,12 @@ def __init__(self, cfg):
                 )
                 name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg
                 name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)]
+            else:
+                logging.error(
+                    f"Unrecognized target_module string: {module}.\n"
+                    f"The possible options are: {list(PEFT_MODULE_MAP.values())}"
+                )
+                exit(1)
 
         self.name_key_to_mcore_mixins = name_key_to_mcore_mixins
         super().__init__(lora_cfg, name_key_to_cfg)
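
Note on the parallel_adapters.py change above: with sequence parallelism enabled, adapters attached to attention_qkv/linear_fc1 receive activations sharded along the sequence dimension (the preceding layernorm runs under SP), so the sequence dim is gathered before linear_in; adapters attached to attention_dense/linear_fc2 must hand an SP-sharded tensor back to the surrounding module, so the output is scattered after linear_out. The single-process sketch below only illustrates that shape contract; toy_gather_sp/toy_scatter_sp, the plain nn.Linear layers, the relu activation, and the tp/shape constants are illustrative stand-ins, not the NeMo/Megatron-Core implementation (the real gather_from_sequence_parallel_region/scatter_to_sequence_parallel_region are distributed collectives that also handle the backward pass).

# Toy sketch of the SP shape handling in the LoRA adapter forward() (assumptions noted above).
import torch
import torch.nn as nn

tp = 2                                                  # assumed tensor-parallel size
seq, batch, hidden, lora_rank = 8, 2, 16, 4             # assumed toy shapes

linear_in = nn.Linear(hidden, lora_rank, bias=False)    # stand-in for the adapter's linear_in
linear_out = nn.Linear(lora_rank, hidden, bias=False)   # stand-in for the adapter's linear_out

def toy_gather_sp(x_shard):
    # pretend every rank holds an identical shard; the real op is an all-gather over the seq dim
    return torch.cat([x_shard] * tp, dim=0)

def toy_scatter_sp(x_full, rank=0):
    # keep only this rank's slice of the sequence dimension
    return torch.chunk(x_full, tp, dim=0)[rank]

def lora_forward(x, sequence_parallel, input_is_parallel):
    if sequence_parallel and not input_is_parallel:
        # attention_qkv / linear_fc1 case: input arrives SP-sharded as [s/tp, b, h],
        # so gather the full sequence before the adapter's in-projection
        x = toy_gather_sp(x)
    x = linear_out(torch.relu(linear_in(x)))
    if sequence_parallel and input_is_parallel:
        # attention_dense / linear_fc2 case: the surrounding module expects an
        # SP-sharded tensor, so scatter the sequence dimension back afterwards
        x = toy_scatter_sp(x)
    return x

print(lora_forward(torch.randn(seq // tp, batch, hidden), True, False).shape)  # torch.Size([8, 2, 16])
print(lora_forward(torch.randn(seq, batch, hidden), True, True).shape)         # torch.Size([4, 2, 16])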