diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 8a0147a4613a..f113d2378916 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -91,6 +91,7 @@ class ParallelismConfig:
     tensor_model_parallel_size: int
     pipeline_model_parallel_size: int
     virtual_pipeline_model_parallel_size: int
+    microbatch_group_size_per_vp_stage: int
     context_parallel_size: int
     sequence_parallel: bool
     expert_model_parallel_size: int
@@ -114,6 +115,9 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             across GPU ranks. Defaults to 1.
         virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to improve
             performance by reducing the pipeline bubble. Defaults to None.
+        microbatch_group_size_per_vp_stage (Optional[int]): Number of micro-batches executed at a time for a
+            given virtual pipeline stage (both forward and backward). Defaults to None, in which case it is set
+            to pipeline_model_parallel_size, which corresponds to a depth-first schedule.
         context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.
             Defaults to 1.
         sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
@@ -174,6 +178,7 @@ def __init__(
         tensor_model_parallel_size: int = 1,
         pipeline_model_parallel_size: int = 1,
         virtual_pipeline_model_parallel_size: Optional[int] = None,
+        microbatch_group_size_per_vp_stage: Optional[int] = None,
         context_parallel_size: int = 1,
         sequence_parallel: bool = False,
         expert_model_parallel_size: int = 1,
@@ -218,6 +223,7 @@ def __init__(
         self.data_sampler: Optional["DataSampler"] = data_sampler
         self.tensor_model_parallel_size = tensor_model_parallel_size
         self.pipeline_model_parallel_size = pipeline_model_parallel_size
+        self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage if microbatch_group_size_per_vp_stage is not None else pipeline_model_parallel_size
         self.context_parallel_size = context_parallel_size
         self.expert_model_parallel_size = expert_model_parallel_size
         self.moe_extended_tp = moe_extended_tp
@@ -816,6 +822,7 @@ def parallelism(self) -> ParallelismConfig:
             tensor_model_parallel_size=self.tensor_model_parallel_size,
             pipeline_model_parallel_size=self.pipeline_model_parallel_size,
             virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
+            microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
             context_parallel_size=self.context_parallel_size,
             sequence_parallel=self.sequence_parallel,
             expert_model_parallel_size=self.expert_model_parallel_size,
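
For reviewers, a minimal usage sketch of the new argument (assuming the standard NeMo 2.0 entry point `from nemo import lightning as nl`; the parallelism sizes below are illustrative and not part of this patch):

```python
from nemo import lightning as nl

# Illustrative sizes: with pipeline_model_parallel_size=4 and
# virtual_pipeline_model_parallel_size=2, omitting microbatch_group_size_per_vp_stage
# leaves it at None, so __init__ falls back to pipeline_model_parallel_size (4),
# i.e. the depth-first schedule described in the docstring.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,
    microbatch_group_size_per_vp_stage=4,  # the new argument added by this patch
)

# The value is stored on the strategy and carried through the ParallelismConfig
# returned by the `parallelism` property (see the last hunk above).
assert strategy.microbatch_group_size_per_vp_stage == 4
```

Leaving the argument unset keeps the previous behavior, since `__init__` falls back to `pipeline_model_parallel_size`, so existing configs are unaffected.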