diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 5955276eda56..d23e57941aaf 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -24,6 +24,7 @@
 
 import torch
 import torch.distributed
+from megatron.core.distributed import DistributedDataParallelConfig
 from torch import Tensor, nn
 
 DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor])
@@ -105,6 +106,7 @@ def __init__(
         forward_step: Optional[Callable[[nn.Module, DataT], Tensor]] = None,
         loss_reduction: Optional[Callable[[nn.Module], "MegatronLossReduction"]] = None,
         vp_size: Optional[int] = None,
+        ddp_config: Optional[DistributedDataParallelConfig] = None,
         cpu: bool = False,
     ) -> None:
         from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
@@ -130,6 +132,23 @@ def __init__(
                     _model.configure_model()
                 _pipeline.append(_model)
 
+        if isinstance(ddp_config, DistributedDataParallelConfig):
+            from megatron.core.distributed import DistributedDataParallel as McoreDDP
+
+            _pipeline = [
+                McoreDDP(
+                    model_chunk.config,
+                    ddp_config,
+                    model_chunk,
+                    data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
+                    # model chunks is overlapped with compute anyway.
+                    disable_bucketing=(model_chunk_idx > 0),
+                )
+                for (model_chunk_idx, model_chunk) in enumerate(_pipeline)
+            ]
+
         for i, model_module in enumerate(_pipeline):
             if not cpu:
                 model_module.cuda(torch.cuda.current_device())
@@ -162,6 +181,7 @@ def __init__(
         self.data_step = data_step or default_data_step
         self.forward_step = forward_step or default_forward_step
         self.loss_reduction: MegatronLossReduction = loss_reduction
+        self.ddp_config = ddp_config
 
     def forward(
         self,
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index c002ecf7fd68..8fa178d7df01 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -4,13 +4,14 @@
 from collections import OrderedDict
 from contextlib import ExitStack
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Literal, Mapping, Optional, TypeVar, Union, cast
 
 import pytorch_lightning as pl
 import torch
 import torch.distributed
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.optimizer import _optimizers_to_device
+from megatron.core.distributed import DistributedDataParallelConfig
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop
@@ -38,6 +39,9 @@
 
 ConfigT = TypeVar("ConfigT")
 
+DDPLiteral = Literal["megatron", "pytorch"]
+
+
 class MegatronStrategy(DDPStrategy, io.IOMixin):
     """Megatron plugin for Pytorch Lightning.
 
@@ -58,11 +62,11 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment=None,  # TODO: Add type-hint
         checkpoint_io=None,  # TODO: Add type-hint
-        no_ddp_communication_hook: bool = True,
         find_unused_parameters: bool = False,
         enable_nemo_ckpt_io: bool = True,
         ckpt_type: TrainerCkptProtocol = TrainerCheckpoint,
         ckpt_include_optimizer: bool = False,
+        ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         **kwargs,
     ) -> None:
@@ -73,7 +77,7 @@ def __init__(
             find_unused_parameters=find_unused_parameters,
             **kwargs,
         )
-        self.no_ddp_communication_hook = no_ddp_communication_hook
+
         self.megatron_callbacks = CallbackConnector()
         self.data_sampler: Optional['DataSampler'] = data_sampler
         self.tensor_model_parallel_size = tensor_model_parallel_size
@@ -85,6 +89,16 @@ def __init__(
         self.lazy_init = lazy_init
         self.ckpt_include_optimizer = ckpt_include_optimizer
 
+        if ddp == "megatron":
+            self.ddp_config = DistributedDataParallelConfig()
+        elif isinstance(ddp, DistributedDataParallelConfig):
+            self.ddp_config = ddp
+        elif ddp == "pytorch":
+            self.ddp_config = None
+            self.no_ddp_communication_hook = False
+        else:
+            raise ValueError(f"Invalid DDP type: {ddp}")
+
         # used in NVIDIA NGC PyTorch containers
         _strategy_lib.enable_nvidia_optimizations()
 
@@ -153,6 +167,9 @@ def setup(self, trainer: pl.Trainer) -> None:
 
         # set up optimizers after the wrapped module has been moved to the device
         self.setup_optimizers(trainer)
+
+        # TODO: Throw an exception if we have a mcore optimizer and no ddp_config
+
         if hasattr(self.precision_plugin, "convert_optimizer"):
             _optimizers = [*self.optimizers]
             _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
@@ -204,6 +221,7 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             precision_plugin=self.precision_plugin,
             vp_size=self.virtual_pipeline_model_parallel_size,
             cpu=isinstance(trainer.accelerator, CPUAccelerator),
+            ddp_config=self.ddp_config,
         )
         self.model = self.megatron_parallel
         self.model.trainer = trainer
@@ -212,6 +230,10 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             self.model = self.precision_plugin.convert_module(self.model)
         self.model.callbacks.add(getattr(trainer, "callbacks"))
 
+        if hasattr(self, "optimizers") and self.optimizers:
+            for optimizer in self.optimizers:
+                self.model.callbacks.add(optimizer)
+
         if self.data_sampler:
             self.model.callbacks.add(self.data_sampler)
 
@@ -223,10 +245,11 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
     def configure_ddp(self) -> None:
         logging.debug(f"{self.__class__.__name__}: configuring MegatronParallel")
         self.model = self._setup_model(self.model)
-        self._register_ddp_hooks()
+        if self.ddp_config is None:
+            self._register_ddp_hooks()
 
     @override
-    def _setup_model(self, model: nn.Module) -> DistributedDataParallel:
+    def _setup_model(self, model: nn.Module) -> nn.Module:
         """Only called when we need to wrap the model for pytorch's ddp."""
         from megatron.core import parallel_state
 
@@ -236,16 +259,19 @@ def _setup_model(self, model: nn.Module) -> DistributedDataParallel:
         if app_state.model_parallel_size is not None:
             self._ddp_kwargs["process_group"] = parallel_state.get_data_parallel_group()
 
-        dist_data_parallel: DistributedDataParallel = super()._setup_model(model)
-        if self.no_ddp_communication_hook:
-            # When using custom gradient accumulation and allreduce, disable
-            # DDP communication hook that works on the gradient bucket.
-            # Instead, use the custom gradient function and communication hook,
-            # which is defined in the master optimizer wrapper.
-            dist_data_parallel.require_backward_grad_sync = False
-            dist_data_parallel.register_comm_hook(None, noop_hook)
+        # Only wrap the model if we are not using Megatron's DDP
+        if not self.ddp_config:
+            dist_data_parallel: DistributedDataParallel = super()._setup_model(model)
+            if self.no_ddp_communication_hook:
+                # When using custom gradient accumulation and allreduce, disable
+                # DDP communication hook that works on the gradient bucket.
+                # Instead, use the custom gradient function and communication hook,
+                # which is defined in the master optimizer wrapper.
+                dist_data_parallel.require_backward_grad_sync = False
+                dist_data_parallel.register_comm_hook(None, noop_hook)
+            model = dist_data_parallel
 
-        return dist_data_parallel
+        return model
 
     def _setup_parallel_ranks(self) -> None:
         self.set_world_ranks()
@@ -260,7 +286,7 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training")
 
         with self.precision_plugin.train_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=False, *args, **kwargs)
 
     @override
     def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -269,7 +295,7 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OU
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation")
 
         with self.precision_plugin.val_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
 
     @override
     def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -278,7 +304,7 @@ def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test")
 
         with self.precision_plugin.test_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
 
     @override
     def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -287,7 +313,7 @@ def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPU
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict")
 
         with self.precision_plugin.predict_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
 
     @override
     def teardown(self) -> None:
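Usage sketch (not part of the patch): the new `ddp` argument selects how data parallelism is applied. Passing `"megatron"` (the default) or an explicit `DistributedDataParallelConfig` makes `MegatronParallel` wrap each model chunk in Megatron Core's `DistributedDataParallel`, while `"pytorch"` keeps the previous PyTorch DDP path, now without the no-op communication hook. The snippet below is a minimal sketch; the `DistributedDataParallelConfig` field names and the trainer settings are assumptions about a recent Megatron-LM / Lightning setup, not taken from this diff.

```python
# Illustrative sketch of the new `ddp` options; anything not shown in the diff
# (config fields, trainer arguments) is an assumption, not part of the patch.
import pytorch_lightning as pl
from megatron.core.distributed import DistributedDataParallelConfig

from nemo.lightning.pytorch.strategies import MegatronStrategy

# Default: wrap model chunks with Megatron Core DDP using a default config.
strategy = MegatronStrategy(tensor_model_parallel_size=2, ddp="megatron")

# Explicit Megatron DDP config, e.g. to overlap gradient reduction with compute
# (field names assume a recent megatron.core release).
strategy = MegatronStrategy(
    tensor_model_parallel_size=2,
    ddp=DistributedDataParallelConfig(overlap_grad_reduce=True, use_distributed_optimizer=True),
)

# Fall back to plain PyTorch DistributedDataParallel: standard gradient
# all-reduce, since the no-op communication hook is no longer registered.
strategy = MegatronStrategy(tensor_model_parallel_size=2, ddp="pytorch")

trainer = pl.Trainer(devices=2, accelerator="gpu", strategy=strategy)
```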