
Commit

address comments
Signed-off-by: Gao Deng <gdeng@nvidia.com>
gdengk committed Apr 26, 2024
1 parent 15d01f3 commit 106d503
Showing 4 changed files with 27 additions and 27 deletions.
@@ -59,7 +59,7 @@

 try:
     from megatron.core import ModelParallelConfig, parallel_state
-    from megatron.core.distributed import DistributedDataParallel as DDP
+    from megatron.core.distributed import DistributedDataParallel as McoreDDP
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_config import TransformerConfig
     from megatron.core.utils import init_method_normal, scaled_init_method_normal
@@ -148,13 +148,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
         # set the megatron core model parallel config
         self.model_parallel_config: ModelParallelConfig = self.build_model_parallel_config()

-        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'
-        self.use_mcore_dist_optim = cfg.optim.get('use_mcore_dist_optim', False)
-        if self.use_mcore_dist_optim:
-            assert (
-                self.with_distributed_adam
-            ), "with_distributed_adam must be True when using mcore distributed optimizer"
-
+        self.use_mcore_dist_optim = cfg.optim.get('name') == 'mcore_distributed_optim'
+        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam' or self.use_mcore_dist_optim
         self.with_megatron_fused_adam = cfg.optim.get('name') == 'megatron_fused_adam'

         # used in NVIDIA NGC PyTorch containers
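
In the hunk above, the mcore distributed optimizer is now selected purely by optimizer name, so the separate use_mcore_dist_optim config key and its assert are gone. A minimal stand-alone sketch of the resulting flags, assuming an OmegaConf config that only carries optim.name (everything else omitted; not part of this commit):

# Illustrative only: mirrors the selection logic introduced above.
from omegaconf import OmegaConf

cfg = OmegaConf.create({'optim': {'name': 'mcore_distributed_optim'}})

use_mcore_dist_optim = cfg.optim.get('name') == 'mcore_distributed_optim'
with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam' or use_mcore_dist_optim
with_megatron_fused_adam = cfg.optim.get('name') == 'megatron_fused_adam'

print(use_mcore_dist_optim, with_distributed_adam, with_megatron_fused_adam)  # True True False
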
@@ -323,7 +318,7 @@ def _wrap_model_for_O2(self):
     def get_model_module_list(self):
         if isinstance(self.model, list):
             return [
-                model.module if isinstance(model, (Float16Module, MCoreFloat16Module, DDP)) else model
+                model.module if isinstance(model, (Float16Module, MCoreFloat16Module, McoreDDP)) else model
                 for model in self.model
             ]
         elif isinstance(self.model, (Float16Module, MCoreFloat16Module)):
@@ -935,7 +930,10 @@ def _validate_and_override_config(self):
         # async grad allreduce. This should be fixed!
         # For now we must disable it whenever using the baseline implementaion.
         # The distributed adam from apex does work with gradient accumulation fusion.
-        distributed_fused_adam = self.cfg.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
+        distributed_fused_adam = (
+            self.cfg.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
+            or self.cfg.optim.get('name', 'fused_adam') == 'mcore_distributed_optim'
+        )
         pipeline_model_parallel_size = self.cfg.get('pipeline_model_parallel_size', 1)
         data_parallel_size = app_state.data_parallel_size

22 changes: 10 additions & 12 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -91,7 +91,7 @@
     from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
     from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
-    from megatron.core.distributed import DistributedDataParallel as DDP
+    from megatron.core.distributed import DistributedDataParallel as McoreDDP
     from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads

     # NeMo's implementation of the get_gpt_layer_ammo_spec function is temporarily used
@@ -508,16 +508,16 @@ def setup_mcore_distributed_parallel(self):
         if self.with_distributed_adam and self.use_mcore_dist_optim:
             config = get_model_config(self.model[0])
             ddp_config = DistributedDataParallelConfig(
-                grad_reduce_in_fp32=self.megatron_amp_O2,
-                overlap_grad_reduce=self.cfg.optim.get('mcore_overlap_grad_sync', False),
+                grad_reduce_in_fp32=(self.cfg.optim.get('grad_sync_dtype', 'fp32') == 'fp32'),
+                overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
                 use_distributed_optimizer=True,
-                check_for_nan_in_grad=False,
+                check_for_nan_in_grad=self.cfg.optim.get('check_for_nan_in_grad', False),
                 # mcore bucket_size is based on num of parameters, therefore not
-                # using grad_allreduce_chunk_size_mb to configure bucket_size here
-                bucket_size=self.cfg.optim.get('mcore_ddp_bucket_size', None),
+                # using bucket_cap_mb to configure bucket_size here
+                bucket_size=self.cfg.optim.get('ddp_bucket_size', None),
             )
             self.model = [
-                DDP(
+                McoreDDP(
                     config,
                     ddp_config,
                     model_chunk,
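
The DistributedDataParallelConfig above now reads plain optim keys (grad_sync_dtype, overlap_grad_sync, check_for_nan_in_grad, ddp_bucket_size) instead of the earlier mcore_-prefixed ones. A hedged, stand-alone sketch of that key-to-field mapping; the key names come from the hunk, while the values and defaults shown are purely illustrative:

# Illustrative only: reproduces the mapping above without importing megatron.
from omegaconf import OmegaConf

optim = OmegaConf.create(
    {
        'name': 'mcore_distributed_optim',
        'grad_sync_dtype': 'fp32',       # 'fp32' -> grad_reduce_in_fp32=True
        'overlap_grad_sync': True,       # -> overlap_grad_reduce
        'check_for_nan_in_grad': False,  # forwarded unchanged
        'ddp_bucket_size': None,         # -> bucket_size (a parameter count, not megabytes)
    }
)

ddp_kwargs = dict(
    grad_reduce_in_fp32=(optim.get('grad_sync_dtype', 'fp32') == 'fp32'),
    overlap_grad_reduce=optim.get('overlap_grad_sync', False),
    use_distributed_optimizer=True,
    check_for_nan_in_grad=optim.get('check_for_nan_in_grad', False),
    bucket_size=optim.get('ddp_bucket_size', None),
)
print(ddp_kwargs)
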
@@ -643,16 +643,14 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
                 grad_sync_func = self.reduce_overlap_gradients
                 param_sync_func = self.sync_overlap_parameters
             else:
-                if self.cfg.optim.get("mcore_overlap_grad_sync", False):
+                if self.cfg.optim.get("overlap_grad_sync", False):
                     no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
                     no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func

-                    if self.cfg.optim.get("mcore_delay_grad_reduce", True):
+                    if self.cfg.optim.get("delay_grad_reduce", True):
                         grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
                         grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
-                if self.cfg.optim.get("mcore_overlap_param_sync", False) and self.cfg.optim.get(
-                    "mcore_delay_param_gather", False
-                ):
+                if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
                     param_sync_func = [
                         lambda x: self._optimizer.finish_param_sync(model_index, x)
                         for model_index in range(len(self.model))

Check failure — Code scanning / CodeQL: Loop variable capture (Error). Capture of loop variable model_index in the lambda above.
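
The CodeQL finding refers to Python's late binding of closures: every lambda produced by the comprehension above ends up reading the same model_index. A stand-alone sketch of the pitfall and one common remedy (binding the index eagerly via functools.partial, or equivalently a default argument); finish_param_sync here is a local stand-in rather than NeMo's optimizer API, and this fix is not part of the commit:

import functools


def finish_param_sync(model_index, bucket):
    # Stand-in for the real optimizer call; just reports which index was bound.
    return f'chunk {model_index} <- {bucket}'


n_chunks = 3

# Late binding: all three lambdas see the final value of model_index (2).
late = [lambda x: finish_param_sync(model_index, x) for model_index in range(n_chunks)]
print([f('grads') for f in late])   # ['chunk 2 <- grads', 'chunk 2 <- grads', 'chunk 2 <- grads']

# Eager binding: each callable captures its own index at creation time.
eager = [functools.partial(finish_param_sync, model_index) for model_index in range(n_chunks)]
print([f('grads') for f in eager])  # ['chunk 0 <- grads', 'chunk 1 <- grads', 'chunk 2 <- grads']
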
8 changes: 7 additions & 1 deletion nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -55,6 +55,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
         if self.cfg.model.get('fsdp', False):
             assert (
                 not self.cfg.model.optim.get('name') == 'distributed_fused_adam'
+                and not self.cfg.model.optim.get('name') == 'mcore_distributed_optim'
             ), 'Distributed optimizer cannot be used with FSDP.'
             sharded_checkpoint = self.cfg.model.get('fsdp_sharded_checkpoint', False)
             if self.cfg.model.get('tensor_model_parallel_size', 1) > 1:
@@ -100,7 +101,12 @@ def _plugins(self) -> list:
"""
megatron_amp_O2 = self.cfg.model.get('megatron_amp_O2', False)
with_distributed_adam = (
self.cfg.model.optim.get('name') == 'distributed_fused_adam' if self.cfg.model.get('optim') else False
(
self.cfg.model.optim.get('name') == 'distributed_fused_adam'
or self.cfg.model.optim.get('name') == 'mcore_distributed_optim'
)
if self.cfg.model.get('optim')
else False
)

plugins = []
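
The expanded conditional above keeps the surrounding verbose style; an equivalent, more compact form would test membership against the two distributed-optimizer names, sketched here for comparison only (not how the file is written):

from typing import Optional


def uses_distributed_adam(optim_name: Optional[str]) -> bool:
    # True for both the apex-based and the mcore-based distributed optimizer names.
    return optim_name in ('distributed_fused_adam', 'mcore_distributed_optim')


print(uses_distributed_adam('mcore_distributed_optim'))  # True
print(uses_distributed_adam(None))                       # False
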
6 changes: 2 additions & 4 deletions nemo/core/classes/modelPT.py
@@ -592,8 +592,8 @@ def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictCo
             adam_beta2=optim_config['betas'][1],
             clip_grad=self.trainer.gradient_clip_val,
             use_distributed_optimizer=self.use_mcore_dist_optim,
-            overlap_grad_reduce=self.cfg.optim.get('mcore_overlap_grad_sync', False),
-            overlap_param_gather=self.cfg.optim.get('mcore_overlap_param_sync', False),
+            overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
+            overlap_param_gather=self.cfg.optim.get('overlap_param_sync', False),
         )
         return megatron_optim_config

@@ -678,8 +678,6 @@ def setup_optimization(
         if optimizer_cls is None:
             # Try to get optimizer name for dynamic resolution, defaulting to Adam
             optimizer_name = optim_config.get('name', 'adam')
-            if optimizer_name == "distributed_fused_adam" and self.use_mcore_dist_optim:
-                optimizer_name = "mcore_distributed_optim"
         else:
             if inspect.isclass(optimizer_cls):
                 optimizer_name = optimizer_cls.__name__.lower()
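
With the remapping above removed, setup_optimization no longer rewrites 'distributed_fused_adam' into 'mcore_distributed_optim'; the mcore optimizer is used only when the config names it explicitly. A small sketch of the resulting name resolution, assuming optim_config behaves like a plain mapping:

# Illustrative only: optimizer-name resolution after this change.
def resolve_optimizer_name(optim_config: dict) -> str:
    return optim_config.get('name', 'adam')


print(resolve_optimizer_name({'name': 'mcore_distributed_optim'}))  # mcore_distributed_optim
print(resolve_optimizer_name({'name': 'distributed_fused_adam'}))   # distributed_fused_adam (no longer remapped)
print(resolve_optimizer_name({}))                                   # adam
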
