diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 121c03153b7f..09cf25db61fc 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -59,7 +59,7 @@
 try:
     from megatron.core import ModelParallelConfig, parallel_state
-    from megatron.core.distributed import DistributedDataParallel as DDP
+    from megatron.core.distributed import DistributedDataParallel as McoreDDP
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_config import TransformerConfig
     from megatron.core.utils import init_method_normal, scaled_init_method_normal

@@ -148,13 +148,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
         # set the megatron core model parallel config
         self.model_parallel_config: ModelParallelConfig = self.build_model_parallel_config()

-        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'
-        self.use_mcore_dist_optim = cfg.optim.get('use_mcore_dist_optim', False)
-        if self.use_mcore_dist_optim:
-            assert (
-                self.with_distributed_adam
-            ), "with_distributed_adam must be True when using mcore distributed optimizer"
-
+        self.use_mcore_dist_optim = cfg.optim.get('name') == 'mcore_distributed_optim'
+        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam' or self.use_mcore_dist_optim
         self.with_megatron_fused_adam = cfg.optim.get('name') == 'megatron_fused_adam'

         # used in NVIDIA NGC PyTorch containers
@@ -323,7 +318,7 @@ def _wrap_model_for_O2(self):
     def get_model_module_list(self):
         if isinstance(self.model, list):
             return [
-                model.module if isinstance(model, (Float16Module, MCoreFloat16Module, DDP)) else model
+                model.module if isinstance(model, (Float16Module, MCoreFloat16Module, McoreDDP)) else model
                 for model in self.model
             ]
         elif isinstance(self.model, (Float16Module, MCoreFloat16Module)):
@@ -935,7 +930,10 @@ def _validate_and_override_config(self):
         # async grad allreduce. This should be fixed!
         # For now we must disable it whenever using the baseline implementaion.
         # The distributed adam from apex does work with gradient accumulation fusion.
-        distributed_fused_adam = self.cfg.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
+        distributed_fused_adam = (
+            self.cfg.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
+            or self.cfg.optim.get('name', 'fused_adam') == 'mcore_distributed_optim'
+        )
         pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1)
         data_parallel_size = app_state.data_parallel_size

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 8496c0d85ded..c8d1067046be 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -91,7 +91,7 @@
     from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
     from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
-    from megatron.core.distributed import DistributedDataParallel as DDP
+    from megatron.core.distributed import DistributedDataParallel as McoreDDP
    from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads

     # NeMo's implementation of the get_gpt_layer_ammo_spec function is temporarily used
@@ -508,16 +508,16 @@ def setup_mcore_distributed_parallel(self):
         if self.with_distributed_adam and self.use_mcore_dist_optim:
             config = get_model_config(self.model[0])
             ddp_config = DistributedDataParallelConfig(
-                grad_reduce_in_fp32=self.megatron_amp_O2,
-                overlap_grad_reduce=self.cfg.optim.get('mcore_overlap_grad_sync', False),
+                grad_reduce_in_fp32=(self.cfg.optim.get('grad_sync_dtype', 'fp32') == 'fp32'),
+                overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
                 use_distributed_optimizer=True,
-                check_for_nan_in_grad=False,
+                check_for_nan_in_grad=self.cfg.optim.get('check_for_nan_in_grad', False),
                 # mcore bucket_size is based on num of parameters, therefore not
-                # using grad_allreduce_chunk_size_mb to configure bucket_size here
-                bucket_size=self.cfg.optim.get('mcore_ddp_bucket_size', None),
+                # using bucket_cap_mb to configure bucket_size here
+                bucket_size=self.cfg.optim.get('ddp_bucket_size', None),
             )
             self.model = [
-                DDP(
+                McoreDDP(
                     config,
                     ddp_config,
                     model_chunk,
@@ -643,16 +643,14 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
                 grad_sync_func = self.reduce_overlap_gradients
                 param_sync_func = self.sync_overlap_parameters
             else:
-                if self.cfg.optim.get("mcore_overlap_grad_sync", False):
+                if self.cfg.optim.get("overlap_grad_sync", False):
                     no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
                     no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func

-                    if self.cfg.optim.get("mcore_delay_grad_reduce", True):
+                    if self.cfg.optim.get("delay_grad_reduce", True):
                         grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
                         grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
-                if self.cfg.optim.get("mcore_overlap_param_sync", False) and self.cfg.optim.get(
-                    "mcore_delay_param_gather", False
-                ):
+                if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
                     param_sync_func = [
                         lambda x: self._optimizer.finish_param_sync(model_index, x)
                         for model_index in range(len(self.model))
diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py
index 6b9763a53414..8c9b6a52080b 100644
--- a/nemo/collections/nlp/parts/megatron_trainer_builder.py
+++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -55,6 +55,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
         if self.cfg.model.get('fsdp', False):
             assert (
                 not self.cfg.model.optim.get('name') == 'distributed_fused_adam'
+                and not self.cfg.model.optim.get('name') == 'mcore_distributed_optim'
             ), 'Distributed optimizer cannot be used with FSDP.'
             sharded_checkpoint = self.cfg.model.get('fsdp_sharded_checkpoint', False)
             if self.cfg.model.get('tensor_model_parallel_size', 1) > 1:
@@ -100,7 +101,12 @@ def _plugins(self) -> list:
         """
         megatron_amp_O2 = self.cfg.model.get('megatron_amp_O2', False)
         with_distributed_adam = (
-            self.cfg.model.optim.get('name') == 'distributed_fused_adam' if self.cfg.model.get('optim') else False
+            (
+                self.cfg.model.optim.get('name') == 'distributed_fused_adam'
+                or self.cfg.model.optim.get('name') == 'mcore_distributed_optim'
+            )
+            if self.cfg.model.get('optim')
+            else False
         )

         plugins = []
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 383025721328..f1691fff7c4b 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -592,8 +592,8 @@ def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictCo
             adam_beta2=optim_config['betas'][1],
             clip_grad=self.trainer.gradient_clip_val,
             use_distributed_optimizer=self.use_mcore_dist_optim,
-            overlap_grad_reduce=self.cfg.optim.get('mcore_overlap_grad_sync', False),
-            overlap_param_gather=self.cfg.optim.get('mcore_overlap_param_sync', False),
+            overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
+            overlap_param_gather=self.cfg.optim.get('overlap_param_sync', False),
         )
         return megatron_optim_config

@@ -678,8 +678,6 @@ def setup_optimization(
         if optimizer_cls is None:
             # Try to get optimizer name for dynamic resolution, defaulting to Adam
             optimizer_name = optim_config.get('name', 'adam')
-            if optimizer_name == "distributed_fused_adam" and self.use_mcore_dist_optim:
-                optimizer_name = "mcore_distributed_optim"
         else:
             if inspect.isclass(optimizer_cls):
                 optimizer_name = optimizer_cls.__name__.lower()
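Note (not part of the diff): the changes above route the Megatron Core distributed optimizer through the single optimizer name 'mcore_distributed_optim' and drop the 'mcore_' prefix from the related optim keys. Below is a minimal Python sketch of how a model config could exercise the renamed keys and how the two flags in MegatronBaseModel.__init__ are derived from it; the lr/betas values and the standalone OmegaConf usage are illustrative assumptions, only the key names and the flag logic come from the diff.

# Sketch only: illustrates the renamed optim keys introduced in this diff.
# Numeric values are placeholders; key names match the cfg.optim.get(...) calls added above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optim": {
            "name": "mcore_distributed_optim",  # new optimizer name (replaces optim.use_mcore_dist_optim)
            "lr": 1e-4,                         # placeholder
            "betas": [0.9, 0.95],               # placeholder
            "grad_sync_dtype": "fp32",          # drives grad_reduce_in_fp32
            "overlap_grad_sync": True,          # was mcore_overlap_grad_sync
            "overlap_param_sync": True,         # was mcore_overlap_param_sync
            "delay_grad_reduce": True,          # was mcore_delay_grad_reduce
            "delay_param_gather": False,        # was mcore_delay_param_gather
            "check_for_nan_in_grad": False,     # now configurable instead of hard-coded False
            "ddp_bucket_size": None,            # was mcore_ddp_bucket_size
        }
    }
)

# Flag derivation, mirroring MegatronBaseModel.__init__ after this change:
use_mcore_dist_optim = cfg.optim.get("name") == "mcore_distributed_optim"
with_distributed_adam = cfg.optim.get("name") == "distributed_fused_adam" or use_mcore_dist_optim
assert use_mcore_dist_optim and with_distributed_adam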