Skip to content

Commit 3f5165d

Browse files
committed
Add world-size getter in Engine
Signed-off-by: WoosungMyung <dntjd517@naver.com>
1 parent 1d7b90a commit 3f5165d

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

deepspeed/runtime/engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,10 @@ def get_tensor_parallel_rank(self):
748748
def get_model_parallel_rank(self):
    """Return this process's rank within the model-parallel group.

    Delegates directly to the global ``groups`` utility module.
    """
    rank = groups.get_model_parallel_rank()
    return rank
750750

751+
def get_parallel_world_sizes(self):
    """Return a dict of parallel world sizes for data/tensor parallelism."""
    # Query each parallelism dimension from the shared groups module.
    sizes = {}
    sizes["dp"] = groups.get_data_parallel_world_size()
    sizes["tp"] = groups.get_tensor_model_parallel_world_size()
    return sizes
754+
751755
def get_sequence_parallel_group(self):
    """Return the process group used for sequence parallelism."""
    sp_group = self.seq_parallel_group
    return sp_group
753757

deepspeed/runtime/pipe/engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,12 @@ def is_last_stage(self):
537537
def get_pipeline_parallel_rank(self):
    """Return this process's pipeline stage index (its pipeline-parallel rank)."""
    stage = self.stage_id
    return stage
539539

540+
def get_parallel_world_sizes(self):
    """Return a dict of parallel world sizes for data/tensor/pipeline parallelism."""
    # Start from the base engine's dp/tp sizes, then layer in the
    # pipeline dimension known only to the pipeline engine.
    world_sizes = super().get_parallel_world_sizes()
    world_sizes["pp"] = self.num_stages
    return world_sizes
545+
540546
def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None):
541547
if reduce is None:
542548
return outputs

0 commit comments

Comments
 (0)