refactor group
LiYuRio committed Sep 28, 2022
1 parent 2aec65b commit a60bbc5
Showing 5 changed files with 85 additions and 59 deletions.
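
In short, this commit moves the Group class out of python/paddle/distributed/collective.py into the new python/paddle/distributed/communication/group.py and slims its constructor from Group(rank, rank_num, id=0, ranks=[], pg=None, name=None) to Group(group_rank, id, ranks, pg=None, name=None); the world size is now derived from ranks rather than passed in. A minimal sketch of the new call shape, with hypothetical values, assuming the module path added by this commit is importable:

    from paddle.distributed.communication.group import Group

    # Hypothetical stand-ins for the values used at the real call sites below.
    rank, gid, ranks = 0, 0, [0, 1, 2, 3]

    # Old constructor (removed from collective.py):
    #     Group(rank, rank_num, id=0, ranks=[], pg=None, name=None)
    # New constructor (python/paddle/distributed/communication/group.py):
    #     Group(group_rank, id, ranks, pg=None, name=None)
    group = Group(rank, gid, ranks)
    assert group.world_size == len(ranks)  # world size is derived from ranks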
56 changes: 6 additions & 50 deletions python/paddle/distributed/collective.py
@@ -53,53 +53,10 @@
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
from .communication.comm_utils import ReduceOp
+from .communication.group import Group

__all__ = []


-class Group():
-    """
-    The abstract representation of group.
-    """
-
-    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
-        self.rank = rank
-        self.nranks = rank_num
-        self.id = id
-        self.ranks = ranks
-        self.pg = pg
-        self.name = name
-
-    def is_member(self):
-        if self.rank < 0:
-            return False
-        if self.nranks < 2:
-            return False
-        return True
-
-    def get_group_rank(self, rank):
-        if self.is_member() and rank in self.ranks:
-            return self.ranks.index(rank)
-        else:
-            return -1
-
-    @property
-    def process_group(self):
-        return self.pg
-
-    @property
-    def world_size(self):
-        return self.nranks if self.rank >= 0 else -1
-
-    def __repr__(self):
-        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
-            self.rank, self.nranks, self.id)
-        debug_str += ", ".join(map(str, self.ranks))
-        debug_str += "; name: "
-        debug_str += self.name if self.name else "None"
-        return debug_str


_global_env = None


@@ -147,9 +104,8 @@ def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
-        _group_map[_global_env_gid] = Group(genv.rank,
-                                            genv.world_size,
-                                            ranks=list(range(genv.world_size)))
+        _group_map[_global_env_gid] = Group(genv.rank, 0,
+                                            list(range(genv.world_size)))
    return _group_map


@@ -451,7 +407,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    else:
        rank = -1
        pg = None
-    group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+    group = Group(rank, gid, ranks, pg=pg, name=group_name)
    _group_map_by_name[group_name] = group
    _group_map[gid] = group
    _group_map_backend[group] = backend
@@ -476,13 +432,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    ring_id = _new_ring_id()

    if global_rank not in ranks:
-        gp = Group(-1, -1, ring_id, ranks)
+        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
-        gp = Group(group_rank, group_size, ring_id, ranks)
+        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
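A detail worth noting from the hunk above: when the calling rank is not in ranks, new_group now builds the group as Group(-1, ring_id, ranks). A small sketch of what such a non-member group reports, with a hypothetical ring id and ranks, assuming the refactored class shown in the next file is importable:

    from paddle.distributed.communication.group import Group

    # Hypothetical ring id and ranks, mirroring the non-member branch above.
    non_member = Group(-1, 3, [2, 3])

    print(non_member.is_member())        # False: group_rank < 0
    print(non_member.world_size)         # -1 is reported for non-members
    print(non_member.nranks)             # 2, still the length of ranks
    print(non_member.get_group_rank(2))  # -1: non-members get no group rank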
76 changes: 76 additions & 0 deletions python/paddle/distributed/communication/group.py
@@ -0,0 +1,76 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Group():
    """
    The abstract representation of group.
    """

    def __init__(self, group_rank, id, ranks, pg=None, name=None):
        self._group_rank = group_rank
        self._world_size = len(ranks) if group_rank >= 0 else -1
        self._id = id
        self._ranks = ranks
        self._pg = pg
        self._name = name

    @property
    def rank(self):
        return self._group_rank

    @property
    def ranks(self):
        return self._ranks

    @property
    def nranks(self):
        return len(self._ranks)

    @property
    def name(self):
        return self._name

    @property
    def process_group(self):
        return self._pg

    @property
    def world_size(self):
        return self._world_size

    @property
    def id(self):
        return self._id

    def is_member(self):
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        if self.is_member():
            return self.ranks.index(rank)
        else:
            return -1

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str
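
A brief usage sketch of the refactored class as added above, with hypothetical ranks, assuming paddle is installed so the new module path resolves:

    from paddle.distributed.communication.group import Group

    # Hypothetical 4-rank group in which this process is group rank 0.
    group = Group(0, 1, [0, 1, 2, 3])

    print(group.is_member())        # True: rank >= 0 and nranks >= 2
    print(group.world_size)         # 4, i.e. len(ranks)
    print(group.get_group_rank(2))  # 2, the index of global rank 2 in ranks
    print(group)                    # rank: 0, nranks: 4, id: 1, ranks: 0, 1, 2, 3; name: None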
4 changes: 2 additions & 2 deletions python/paddle/distributed/fleet/base/topology.py
@@ -377,8 +377,8 @@ def __init__(self):

def set_comm_group(self, group_name, group_rank, group_size, ring_id,
group_ranks):
group = paddle.distributed.collective.Group(group_rank, group_size,
ring_id, group_ranks)
group = paddle.distributed.collective.Group(group_rank, ring_id,
group_ranks)
self.groups[group_name] = group

def get_group(self, group_name):
7 changes: 1 addition & 6 deletions python/paddle/distributed/parallel.py
@@ -258,12 +258,7 @@ def train():
        _default_group_name,
        pg_options=None)
    ranks = list(range(world_size))
-    group = Group(rank,
-                  world_size,
-                  id=0,
-                  ranks=ranks,
-                  pg=pg,
-                  name=_default_group_name)
+    group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
    _set_group_map_by_name(_default_group_name, group)
    _set_group_map(0, group)
    _set_group_map_backend(group, backend)
1 change: 0 additions & 1 deletion python/paddle/incubate/distributed/models/moe/moe_layer.py
@@ -265,7 +265,6 @@ class MoELayer(nn.Layer):
        from paddle.distributed import fleet
        moe_group = Group(fleet.worker_index(),
                          fleet.worker_num(),
-                         0,
                          list(range(fleet.worker_num())))
        mp_group = None
