
Commit

docstr reformat
suiyoubi committed Nov 7, 2024
1 parent 4afdd92 commit 6a09e8c
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -91,6 +91,7 @@ class ParallelismConfig:
tensor_model_parallel_size: int
pipeline_model_parallel_size: int
virtual_pipeline_model_parallel_size: int
microbatch_group_size_per_vp_stage: int
context_parallel_size: int
sequence_parallel: bool
expert_model_parallel_size: int
@@ -114,6 +115,9 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
across GPU ranks. Defaults to 1.
virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to
improve performance by reducing the pipeline bubble. Defaults to None.
microbatch_group_size_per_vp_stage (Optional[int]): The number of micro-batches that are executed
at a time for a given virtual stage (both forward and backward). Defaults to None, which falls
back to pipeline_model_parallel_size and specifies a depth-first schedule.
context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.
Defaults to 1.
sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
@@ -174,6 +178,7 @@ def __init__(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
microbatch_group_size_per_vp_stage: Optional[int] = None,
context_parallel_size: int = 1,
sequence_parallel: bool = False,
expert_model_parallel_size: int = 1,
@@ -218,6 +223,7 @@ def __init__(
self.data_sampler: Optional["DataSampler"] = data_sampler
self.tensor_model_parallel_size = tensor_model_parallel_size
self.pipeline_model_parallel_size = pipeline_model_parallel_size
self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage if microbatch_group_size_per_vp_stage is not None else pipeline_model_parallel_size
self.context_parallel_size = context_parallel_size
self.expert_model_parallel_size = expert_model_parallel_size
self.moe_extended_tp = moe_extended_tp
@@ -816,6 +822,7 @@ def parallelism(self) -> ParallelismConfig:
tensor_model_parallel_size=self.tensor_model_parallel_size,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
context_parallel_size=self.context_parallel_size,
sequence_parallel=self.sequence_parallel,
expert_model_parallel_size=self.expert_model_parallel_size,
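For context, a minimal sketch of how the new argument might be passed when constructing the strategy. The import path and the surrounding parallelism values are illustrative assumptions; only microbatch_group_size_per_vp_stage and its fallback to pipeline_model_parallel_size come from this commit.

from nemo import lightning as nl

# Hypothetical setup: 4 pipeline stages with an interleaved (virtual)
# schedule of 2 model chunks per rank.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,
    # New in this commit: micro-batches executed back-to-back per virtual
    # stage. Leaving it as None falls back to pipeline_model_parallel_size
    # (4 here), i.e. the depth-first schedule described in the docstring.
    microbatch_group_size_per_vp_stage=4,
)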
