diff --git a/Jenkinsfile b/Jenkinsfile
index f2e0704a0ea5..cb614799eb9d 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -3604,12 +3604,14 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         trainer.limit_val_batches=2 \
         trainer.accumulate_grad_batches=1 \
         trainer.max_steps=3 \
-        trainer.precision=16 \
+        trainer.precision=bf16 \
         trainer.gradient_clip_val=1.0 \
         exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
         model.pipeline_model_parallel_size=2 \
         model.tensor_model_parallel_size=1 \
-        model.optim.name=fused_adam \
+        model.mcore_gpt=True \
+        model.megatron_amp_O2=True \
+        model.optim.name=distributed_fused_adam \
         model.optim.lr=2e-4 \
         model.optim.sched.warmup_steps=1 \
         model.optim.sched.constant_steps=1 \
@@ -3639,13 +3641,15 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         trainer.limit_val_batches=2 \
         trainer.accumulate_grad_batches=1 \
         trainer.max_steps=6 \
-        trainer.precision=16 \
+        trainer.precision=bf16 \
         trainer.gradient_clip_val=1.0 \
+        model.mcore_gpt=True \
+        model.megatron_amp_O2=True \
         exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
         exp_manager.resume_if_exists=True \
         model.pipeline_model_parallel_size=2 \
         model.tensor_model_parallel_size=1 \
-        model.optim.name=fused_adam \
+        model.optim.name=distributed_fused_adam \
         model.optim.lr=2e-4 \
         model.optim.sched.warmup_steps=2 \
         model.optim.sched.constant_steps=2 \
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 2ce65c17ee3e..994bc9ca9479 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1560,3 +1560,36 @@ def build_transformer_config(self) -> TransformerConfig:
             setattr(transformer_config, key, value)

         return transformer_config
+
+    def _wrap_model_for_O2(self):
+        """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
+            Args:
+                model: The model to wrap. Can be a list of modules or a single module.
+            Returns:
+                The wrapped model. Returns a list of wrapped modules or a single wrapped module.
+ """ + Float16Wrapper = MCoreFloat16Module if self.mcore_gpt else Float16Module + + nemo_args = { + 'config': self.model_parallel_config, + 'precision': self.cfg.precision, + 'share_token_embeddings': self.cfg.get('share_embeddings_and_output_weights', True), + } + mcore_args = { + 'config': self.transformer_config, + } + + args = mcore_args if self.mcore_gpt else nemo_args + + # Model wrapper to convert both model and inputs to half precision + if isinstance(self.model, list): + converted_model = [] + for module in self.model: + args['module'] = module + converted_model.append(Float16Wrapper(**args)) + self.model = converted_model + else: + args['module'] = self.model + self.model = Float16Wrapper(**args) + + args.pop('module') diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py index f985f99218e8..91928a86253d 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py @@ -189,7 +189,7 @@ def on_load_checkpoint(self, checkpoint) -> None: # mcore uses distributed checkpointing print('enter peft loading') if self.mcore_gpt: - for index, module in enumerate(self.get_gpt_module_list()): + for index, module in enumerate(self.get_model_module_list()): if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] else: diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 853ffc6ea012..8d626b2c2b7c 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -376,7 +376,7 @@ def on_load_checkpoint(self, checkpoint) -> None: # same as super().on_load_checkpoint() but strict=False and only check unexpected keys # mcore uses distributed checkpointing if hasattr(self, 'mcore_gpt') and self.mcore_gpt: - for index, module in enumerate(self.get_gpt_module_list()): + for index, module in enumerate(self.get_model_module_list()): if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] else: