From a60bbc5a3e497ba10167e9c39e74e44025d82f5b Mon Sep 17 00:00:00 2001
From: LiYuRio
Date: Wed, 21 Sep 2022 15:43:52 +0800
Subject: [PATCH] refactor group

---
 python/paddle/distributed/collective.py       | 56 ++------------
 .../paddle/distributed/communication/group.py | 76 +++++++++++++++++++
 .../paddle/distributed/fleet/base/topology.py |  4 +-
 python/paddle/distributed/parallel.py         |  7 +-
 .../distributed/models/moe/moe_layer.py       |  1 -
 5 files changed, 85 insertions(+), 59 deletions(-)
 create mode 100644 python/paddle/distributed/communication/group.py

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 45d6b00652811..e79d736421f68 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -53,53 +53,10 @@
 from .fleet.layers.mpu.mp_ops import _parallel_linear
 from .fleet.layers.mpu.mp_ops import _parallel_embedding
 from .communication.comm_utils import ReduceOp
+from .communication.group import Group
 
 __all__ = []
 
-
-class Group():
-    """
-    The abstract representation of group.
-    """
-
-    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
-        self.rank = rank
-        self.nranks = rank_num
-        self.id = id
-        self.ranks = ranks
-        self.pg = pg
-        self.name = name
-
-    def is_member(self):
-        if self.rank < 0:
-            return False
-        if self.nranks < 2:
-            return False
-        return True
-
-    def get_group_rank(self, rank):
-        if self.is_member() and rank in self.ranks:
-            return self.ranks.index(rank)
-        else:
-            return -1
-
-    @property
-    def process_group(self):
-        return self.pg
-
-    @property
-    def world_size(self):
-        return self.nranks if self.rank >= 0 else -1
-
-    def __repr__(self):
-        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
-            self.rank, self.nranks, self.id)
-        debug_str += ", ".join(map(str, self.ranks))
-        debug_str += "; name: "
-        debug_str += self.name if self.name else "None"
-        return debug_str
-
-
 _global_env = None
 
 
@@ -147,9 +104,8 @@ def _get_group_map():
     global _group_map
     if _global_env_gid not in _group_map:
         genv = _get_global_env()
-        _group_map[_global_env_gid] = Group(genv.rank,
-                                            genv.world_size,
-                                            ranks=list(range(genv.world_size)))
+        _group_map[_global_env_gid] = Group(genv.rank, 0,
+                                            list(range(genv.world_size)))
     return _group_map
 
 
@@ -451,7 +407,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
         else:
             rank = -1
             pg = None
-        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+        group = Group(rank, gid, ranks, pg=pg, name=group_name)
         _group_map_by_name[group_name] = group
         _group_map[gid] = group
         _group_map_backend[group] = backend
@@ -476,13 +432,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
     ring_id = _new_ring_id()
 
     if global_rank not in ranks:
-        gp = Group(-1, -1, ring_id, ranks)
+        gp = Group(-1, ring_id, ranks)
         _group_map[ring_id] = gp
     else:
         ranks = sorted(ranks)
         group_rank = ranks.index(global_rank)
         group_size = len(ranks)
-        gp = Group(group_rank, group_size, ring_id, ranks)
+        gp = Group(group_rank, ring_id, ranks)
         _group_map[ring_id] = gp
 
         if group_size >= 2:
diff --git a/python/paddle/distributed/communication/group.py b/python/paddle/distributed/communication/group.py
new file mode 100644
index 0000000000000..e9094ba514554
--- /dev/null
+++ b/python/paddle/distributed/communication/group.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Group():
+    """
+    The abstract representation of group.
+    """
+
+    def __init__(self, group_rank, id, ranks, pg=None, name=None):
+        self._group_rank = group_rank
+        self._world_size = len(ranks) if group_rank >= 0 else -1
+        self._id = id
+        self._ranks = ranks
+        self._pg = pg
+        self._name = name
+
+    @property
+    def rank(self):
+        return self._group_rank
+
+    @property
+    def ranks(self):
+        return self._ranks
+
+    @property
+    def nranks(self):
+        return len(self._ranks)
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def process_group(self):
+        return self._pg
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    @property
+    def id(self):
+        return self._id
+
+    def is_member(self):
+        if self.rank < 0:
+            return False
+        if self.nranks < 2:
+            return False
+        return True
+
+    def get_group_rank(self, rank):
+        if self.is_member() and rank in self.ranks:
+            return self.ranks.index(rank)
+        else:
+            return -1
+
+    def __repr__(self):
+        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
+            self.rank, self.nranks, self.id)
+        debug_str += ", ".join(map(str, self.ranks))
+        debug_str += "; name: "
+        debug_str += self.name if self.name else "None"
+        return debug_str
diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 0f72bfe9be28d..b841542312ef8 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -377,8 +377,8 @@ def __init__(self):
 
     def set_comm_group(self, group_name, group_rank, group_size, ring_id,
                        group_ranks):
-        group = paddle.distributed.collective.Group(group_rank, group_size,
-                                                    ring_id, group_ranks)
+        group = paddle.distributed.collective.Group(group_rank, ring_id,
+                                                    group_ranks)
         self.groups[group_name] = group
 
     def get_group(self, group_name):
diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py
index 8a22526d3c2e5..29825e81b774a 100644
--- a/python/paddle/distributed/parallel.py
+++ b/python/paddle/distributed/parallel.py
@@ -258,12 +258,7 @@ def train():
                                      _default_group_name,
                                      pg_options=None)
         ranks = list(range(world_size))
-        group = Group(rank,
-                      world_size,
-                      id=0,
-                      ranks=ranks,
-                      pg=pg,
-                      name=_default_group_name)
+        group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
         _set_group_map_by_name(_default_group_name, group)
         _set_group_map(0, group)
         _set_group_map_backend(group, backend)
diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py
index 58b026a3b2a30..7c11a3e6393cc 100644
--- a/python/paddle/incubate/distributed/models/moe/moe_layer.py
+++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py
@@ -265,7 +265,6 @@ class MoELayer(nn.Layer):
         from paddle.distributed import fleet
 
         moe_group = Group(fleet.worker_index(),
-                          fleet.worker_num(),
                           0,
                           list(range(fleet.worker_num())))
         mp_group = None
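
For reference, a minimal sketch of how the refactored constructor is meant to be called after this patch, assuming the Group class exactly as added in python/paddle/distributed/communication/group.py above. The rank, id, and world-size values below are illustrative only and are not defaults taken from Paddle:

    # Sketch only: mirrors the positional call sites updated in this patch,
    # Group(group_rank, id, ranks), with made-up rank/id values.
    from paddle.distributed.communication.group import Group

    ranks = list(range(4))
    group = Group(1, 0, ranks)        # group_rank=1, id=0, ranks=[0, 1, 2, 3]
    print(group.world_size)           # 4, derived from len(ranks) because group_rank >= 0
    print(group.nranks)               # 4
    print(group.get_group_rank(3))    # 3, the position of global rank 3 inside this group
    print(group.get_group_rank(7))    # -1, rank 7 is not a member of this group

    outsider = Group(-1, 0, ranks)    # processes outside the group pass group_rank=-1
    print(outsider.world_size)        # -1
    print(outsider.is_member())       # False

Unlike the old Group(rank, rank_num, ...) signature, the group size is no longer passed in; it is derived from ranks, which is why call sites such as new_group and set_comm_group now drop their size argument.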