
[HybridParallel]Support 1f1b for PipelineParallel #34483

Merged: 14 commits, Aug 2, 2021

python/paddle/distributed/fleet/base/topology.py (63 changes: 42 additions & 21 deletions)
@@ -156,6 +156,10 @@ def __init__(self, topology):
         self.is_first_stage = (self.stage_id == 0)
         self.is_last_stage = (self.stage_id == (self._pp_degree - 1))
 
+        # create p2p_groups
+        if self._pp_degree > 1:
+            self._set_p2p_group()
+
         debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
                     "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
                     self._sharding_degree, self._pp_degree, self._dp_degree)
@@ -164,27 +168,9 @@ def __init__(self, topology):
             self._dp_group, self._check_group)
         logger.info(debug_str)
 
-        # create p2p_groups and no new group
-        self._p2p_groups = self._build_p2p_lists()
-
         global _HYBRID_PARALLEL_GROUP
         _HYBRID_PARALLEL_GROUP = self
 
-    def _build_p2p_lists(self):
-        comm_lists = self._topo.get_comm_list('pipe')
-        p2p_lists = []
-        for rank in range(self.nranks):
-            for comm_ranks in comm_lists:
-                assert len(comm_ranks) == self._pp_degree
-                if rank in comm_ranks:
-                    idx = comm_ranks.index(rank)
-                    next_rank = comm_ranks[(idx + 1) % self._pp_degree]
-                    p2p_lists.append([rank, next_rank])
-                    break
-        assert len(
-            p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks"
-        return p2p_lists
-
     def get_parallel_mode(self):
         # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
         # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
@@ -236,6 +222,41 @@ def _set_check_group(self, parallel_method="data"):
 
         return parallel_group, parallel_comm_group
 
+    def _set_p2p_group(self):
+        comm_lists = self._topo.get_comm_list('pipe')
+
+        self.send_next_group = None
+        self.send_prev_group = None
+        self.recv_next_group = None
+        self.recv_prev_group = None
+
+        for comm_ranks in comm_lists:
+            assert len(comm_ranks) == self._pp_degree
+            for idx, rank in enumerate(comm_ranks):
+                curr_rank = rank
+                next_rank = comm_ranks[(idx + 1) % self._pp_degree]
+                prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
+
+                next_group = paddle.distributed.new_group(
+                    ranks=[curr_rank, next_rank])
+                if self.global_rank == curr_rank:
+                    self.send_next_group = next_group
+                elif self.global_rank == next_rank:
+                    self.recv_prev_group = next_group
+
+                prev_group = paddle.distributed.new_group(
+                    ranks=[prev_rank, curr_rank])
+
+                if self.global_rank == curr_rank:
+                    self.send_prev_group = prev_group
+                elif self.global_rank == prev_rank:
+                    self.recv_next_group = prev_group
+
+        assert self.send_next_group is not None
+        assert self.send_prev_group is not None
+        assert self.recv_next_group is not None
+        assert self.recv_prev_group is not None
+
     def topology(self):
         return self._topo
 
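For reference, here is a minimal sketch of the pairing that _set_p2p_group performs for one 'pipe' communication list. It is plain Python only, not part of the patch: no real communication groups are created, and pp_degree and comm_ranks are made-up example values.

# Illustration only: mimic the pairing logic of _set_p2p_group with plain
# lists instead of paddle.distributed.new_group calls.
pp_degree = 4                      # made-up example value
comm_ranks = [0, 1, 2, 3]          # one list as returned by get_comm_list('pipe')

groups = {r: {} for r in comm_ranks}
for idx, curr_rank in enumerate(comm_ranks):
    next_rank = comm_ranks[(idx + 1) % pp_degree]
    prev_rank = comm_ranks[(idx - 1) % pp_degree]

    # forward direction: curr_rank sends to next_rank
    groups[curr_rank]["send_next"] = [curr_rank, next_rank]
    groups[next_rank]["recv_prev"] = [curr_rank, next_rank]

    # backward direction: curr_rank sends to prev_rank
    groups[curr_rank]["send_prev"] = [prev_rank, curr_rank]
    groups[prev_rank]["recv_next"] = [prev_rank, curr_rank]

for rank in sorted(groups):
    print(rank, groups[rank])
# rank 1, for example, ends up with send_next=[1, 2], recv_prev=[0, 1],
# send_prev=[0, 1], recv_next=[1, 2]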
@@ -287,6 +308,9 @@ def get_pipe_parallel_world_size(self):
     def get_pipe_parallel_group(self):
         return self._pp_comm_group
 
+    def get_p2p_groups(self):
+        return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group
+
     # sharding parallel message:
     def _get_sharding_parallel_id(self):
         return self._topo.get_coord(self.global_rank).sharding
@@ -304,9 +328,6 @@ def get_sharding_parallel_group_src_rank(self):
         # TODO should the src rank related to the shard rank for each parameter ?
         return self._sharding_comm_group.ranks[0]
 
-    def get_p2p_groups(self):
-        return self._p2p_groups
-
     # check parallel group
     def get_check_parallel_group(self):
         return self._check_comm_group
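A hedged usage sketch of the accessor added above, not taken from this PR: it assumes fleet has been initialized with a pipeline degree greater than 1 and that the hybrid communicate group is reached through fleet.get_hybrid_communicate_group(); variable names are illustrative.

# Usage sketch: retrieve the four point-to-point groups built by
# _set_p2p_group. Requires fleet to be initialized with pp_degree > 1.
import paddle.distributed.fleet as fleet

hcg = fleet.get_hybrid_communicate_group()
send_next, send_prev, recv_next, recv_prev = hcg.get_p2p_groups()

# Each value is a two-rank communication group. A 1F1B schedule would use
# send_next/recv_prev to pass activations forward and send_prev/recv_next
# to pass gradients backward.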