From 4e2ed33ad99f9c4527176f146e420ab8eaf8ef6b Mon Sep 17 00:00:00 2001
From: Sangkug Lym
Date: Mon, 4 Dec 2023 11:23:07 -0800
Subject: [PATCH] Add interface to set NCCL options of each process group
 (#7923)

Signed-off-by: Sangkug Lym
Co-authored-by: Eric Harper
---
 examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 1 +
 nemo/collections/nlp/parts/megatron_trainer_builder.py       | 1 +
 nemo/collections/nlp/parts/nlp_overrides.py                  | 5 ++++-
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index bd34d54f5fd68..8a7dd689e9702 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -131,6 +131,7 @@ model:
   apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
   gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
   sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages
+  nccl_communicator_config_path: null # Path to the yaml file with NCCL communicator options (min_ctas, max_ctas, and cga_cluster_size)
 
   ## Activation Checkpointing
   # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py
index b2554a35cdbdc..69956129bdde6 100644
--- a/nemo/collections/nlp/parts/megatron_trainer_builder.py
+++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -53,6 +53,7 @@ def _training_strategy(self) -> NLPDDPStrategy:
             no_ddp_communication_hook=True,
             gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
             find_unused_parameters=False,
+            nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
         )
 
     def _grad_scaler(self) -> GradScaler:
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 0d13c31d99656..82cdae381701e 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -81,6 +81,7 @@ class NLPDDPStrategy(DDPStrategy):
     Args:
         no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
             with FP32 gradient accumulation.
+        nccl_communicator_config_path: Path to the yaml file with NCCL communicator options
     """
 
     def __init__(
@@ -89,6 +90,7 @@ def __init__(
         cluster_environment: ClusterEnvironment = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         no_ddp_communication_hook: bool = False,
+        nccl_communicator_config_path: Optional[str] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         if not HAVE_APEX:
@@ -103,6 +105,7 @@ def __init__(
         super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs)
 
         self.no_ddp_communication_hook = no_ddp_communication_hook
+        self.nccl_communicator_config_path = nccl_communicator_config_path
 
     def setup(self, trainer: "pl.Trainer") -> None:
         """
@@ -180,7 +183,6 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
 
         Args:
             global_rank (int): the global process index.
             world_size (int): the total number of GPUs, num_nodes * num_devices
-            is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
         """
         app_state = AppState()
@@ -196,6 +198,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
                 pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
                 virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
                 pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+                nccl_communicator_config_path=self.nccl_communicator_config_path,
             )
 
             # assert that fake tp and pp rank match after model parallel init