Skip to content

Commit

Permalink
[Auto Parallel] Add the graph class for the process and cluster (#37482)
Browse files Browse the repository at this point in the history
* [Auto Parallel]  Add the unified cluster representation

* [Auto Parallel] Add the graph class for physical mapping

* [Auto Parallel] Add the simple physical mapper

* Set the timeout of the mapper

* Merge the upstream develop unittests cmake files

* Fix a bug of the process group

* Remove mapper unittest from platforms which is not GPU

* Move the instantiation of process group after resharding

* Add the local id for devices

* Update the rank mapping format

* Add some comments

* Remove the related files about mapping

* Remove unused rank_mapping unittest

* Improve the unittest coverage
  • Loading branch information
aoyulong authored Nov 27, 2021
1 parent e7bda1d commit 48faf63
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 44 deletions.
172 changes: 172 additions & 0 deletions python/paddle/distributed/auto_parallel/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License


class Node:
    """A graph node with a unique id and a dict of arbitrary attributes."""

    def __init__(self, id, **attrs):
        """Create a node.

        Args:
            id: Unique identifier of this node within a graph.
            **attrs: Arbitrary key-value attributes attached to the node.
        """
        # Each node must have a unique id
        self._id = id
        # Attributes for Node
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def id(self):
        return self._id

    @property
    def attrs(self):
        return self._attrs

    def __getitem__(self, attr_name):
        return self._attrs[attr_name]

    def __setitem__(self, attr_name, attr_value):
        self._attrs[attr_name] = attr_value

    def __contains__(self, attr_name):
        # An unhashable attr_name can never be a dict key, so treat the
        # TypeError from the lookup as "not contained" instead of raising.
        try:
            return attr_name in self._attrs
        except TypeError:
            return False

    def __str__(self):
        # Build the representation directly; do not shadow the builtin `str`.
        return "(id: {}, attrs: {})".format(self.id, self.attrs)


class Edge:
    """A directed edge from a source node to a target node with attributes."""

    def __init__(self, src_id, tgt_id, **attrs):
        """Create an edge.

        Args:
            src_id: Id of the source node of this edge.
            tgt_id: Id of the target node of this edge.
            **attrs: Arbitrary key-value attributes attached to the edge.
        """
        # The id of source node in an Edge
        self._src_id = src_id
        # The id of target node in an Edge
        self._tgt_id = tgt_id
        # Attributes for Edge
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def src_id(self):
        return self._src_id

    @property
    def tgt_id(self):
        return self._tgt_id

    @property
    def attrs(self):
        return self._attrs

    def __getitem__(self, attr_name):
        return self._attrs[attr_name]

    def __setitem__(self, attr_name, attr_value):
        self._attrs[attr_name] = attr_value

    def __contains__(self, attr_name):
        # An unhashable attr_name can never be a dict key, so treat the
        # TypeError from the lookup as "not contained" instead of raising.
        try:
            return attr_name in self._attrs
        except TypeError:
            return False

    def __str__(self):
        # Build the representation in one expression; do not shadow `str`.
        return "(src_id: {}, tgt_id: {}, attrs: {})".format(
            self.src_id, self.tgt_id, self._attrs)


class Graph:
    """A simple directed graph built from Node and Edge objects.

    Nodes are stored in a dict keyed by node id; adjacency is a dict of
    dicts mapping source node id -> {target node id -> Edge}.
    """

    def __init__(self, **attrs):
        # _nodes is dict for storing the nodes of the graph.
        # The key of this dict is the node id.
        self._nodes = {}
        # _adjs is a dict of dict for storing the adjacency of the graph.
        # The key of the outer dict is the node id of the source node and
        # the key of the inner dict is the node id of the target node.
        self._adjs = {}
        # Attributes for Graph
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def nodes(self):
        return self._nodes

    @property
    def attrs(self):
        return self._attrs

    @property
    def adjs(self):
        return self._adjs

    def add_node(self, node_id, **attrs):
        """Add a node, or merge attrs into an already-existing node.

        Raises:
            ValueError: If node_id is None.
        """
        if node_id is None:
            raise ValueError("None cannot be a node")
        if node_id not in self._nodes:
            self._nodes[node_id] = Node(node_id, **attrs)
            self._adjs[node_id] = {}
        else:
            # The node already exists: only merge the new attributes.
            self._nodes[node_id].attrs.update(attrs)

    def add_edge(self, src_id, tgt_id, **attrs):
        """Add a directed edge, creating missing endpoint nodes on demand.

        Raises:
            ValueError: If either endpoint id is None (checked before any
                node is created, so a failed call leaves the graph unchanged).
        """
        if src_id is None or tgt_id is None:
            raise ValueError("None cannot be a node")
        # Reuse add_node so node/adjacency bookkeeping stays in one place.
        self.add_node(src_id)
        self.add_node(tgt_id)
        # Add the edge; an existing edge between the same pair is replaced.
        self._adjs[src_id][tgt_id] = Edge(src_id, tgt_id, **attrs)

    def __len__(self):
        # The size of a graph is its number of nodes.
        return len(self._nodes)

    def __iter__(self):
        # Iterate over Node objects (not ids).
        return iter(self._nodes.values())

    def __getitem__(self, node_id):
        # Return the adjacency dict (tgt_id -> Edge) of a node
        return self._adjs[node_id]

    def __contains__(self, node_id):
        # Check whether a node is in the graph; an unhashable id is never
        # a valid key, so swallow the TypeError and report absence.
        try:
            return node_id in self._nodes
        except TypeError:
            return False

    def __str__(self):
        # Collect lines and join once instead of quadratic `+=` on a
        # variable that shadowed the builtin `str`.
        lines = ["**************Nodes**************"]
        for node_id in self.nodes:
            lines.append("{}".format(self.nodes[node_id]))
        lines.append("**************Edges**************")
        for src_id in self.adjs:
            lines.append("--------------{}--------------".format(src_id))
            # Iterate targets directly; the original enumerate index was unused.
            for tgt_id in self.adjs[src_id]:
                lines.append("{}".format(self.adjs[src_id][tgt_id]))
        return "\n".join(lines) + "\n"
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def parallelize(self,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group._ranks:
if rank not in process_group.ranks:
continue
process_group.instantiate()

Expand Down
89 changes: 62 additions & 27 deletions python/paddle/distributed/auto_parallel/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant

# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = {}


Expand All @@ -27,25 +29,27 @@ def get_all_process_groups():
return _g_process_group_map.values()


def get_process_group(group_id):
global _g_process_group_map
return _g_process_group_map.get(group_id, None)


def new_process_group(ranks):
global _g_process_group_map
if not _g_process_group_map:
genv = _get_global_env()
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in _g_process_group_map:
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
_g_process_group_map[key] = pg
return pg
else:
pg = _g_process_group_map[key]
return pg
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
for pg_id, pg in _g_process_group_map.items():
cur_key = ''.join(map(str, sorted(pg.ranks)))
if pg_id != 0 and new_key == cur_key:
return pg
# If not matching the existing one, construct a new process group
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
new_pg = ProcessGroup(group_id, ranks)
_g_process_group_map[group_id] = new_pg
return new_pg


# This implementation refers to lots of Paddle/python/paddle/distributed/collective.py,
Expand All @@ -56,22 +60,40 @@ def new_process_group(ranks):
# handle the communication implementation choice.
class ProcessGroup:
def __init__(self, group_id, ranks):
if group_id == 0 and get_process_group(0) is not None:
assert group_id != 0, "Process group id 0 is reserved for all ranks."
self._group_id = group_id
self._ranks = sorted(ranks)
self._nranks = len(self._ranks)
# Add the current ranks into group 0
if group_id != 0:
global _g_process_group_map
_g_process_group_map[0].add_ranks(ranks)
self._is_instantiate = False

@property
def id(self):
return self._group_id

# @property
# def key(self):
# return ''.join(map(str, sorted(self._ranks)))
@property
def ranks(self):
return self._ranks

@property
def nranks(self):
return len(self._ranks)

def add_ranks(self, new_ranks):
if set(new_ranks) <= set(self.ranks):
return
else:
assert self.is_instantiate() == False, \
"Cannot add new ranks after instantiating the process group"
self._ranks.extend(new_ranks)
self._ranks = sorted(list(set(self.ranks)))

def local_rank(self, global_rank):
if global_rank in self._ranks:
return self._ranks.index(global_rank)
if global_rank in self.ranks:
return self.ranks.index(global_rank)
else:
assert False, \
"Rank {} doesn't belong to this group".format(global_rank)
Expand All @@ -86,12 +108,12 @@ def instantiate(self):
genv = _get_global_env()
global_rank = genv.rank

if self._nranks >= 2:
if self.nranks >= 2:
strategy = core.ParallelStrategy()
strategy.nranks = self._nranks
strategy.nranks = self.nranks
strategy.local_rank = self.local_rank(global_rank)
strategy.trainer_endpoints = [
genv.trainer_endpoints[i] for i in self._ranks
genv.trainer_endpoints[i] for i in self.ranks
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
Expand All @@ -113,7 +135,20 @@ def instantiate(self):

self._is_instantiate = True

# def __eq__(self, other):
# if not isinstance(other, ProcessGroup):
# return False
# if self.id != other.id:
# return False
# return True

# def __ne__(self, other):
# return not self.__eq__(other)

def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self._nranks, ", ".join(map(str, self._ranks)))
self.id, self.nranks, ", ".join(map(str, self.ranks)))
return string


_g_process_group_map[0] = ProcessGroup(0, [])
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/process_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ def __init__(self, mesh):
self._topology = _get_nested_list_shape(mesh)
self._processes = processes

# Store all process meshes
from .dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self)
# Add new processes to process group 0
from .process_group import get_process_group
pg0 = get_process_group(0)
pg0.add_ranks(self.processes)

@property
def topology(self):
Expand Down
14 changes: 2 additions & 12 deletions python/paddle/distributed/auto_parallel/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,13 @@ def _insert_fill_constant_op(block, idx):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group._nranks
'nranks': group.nranks
})
idx_offset += 1

# insert split op
split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
group._nranks)
group.nranks)
idx_offset += 1
tensor_list.extend(split_out)
return tensor_list, idx_offset
Expand Down Expand Up @@ -665,14 +665,6 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
partition_tensor_list.append((tensor, partition_index))


def _init_comm_for_send_recv():
if not _g_process_group_map:
genv = _get_global_env()
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
_g_process_group_map["global_group"].instantiate()


HAS_SENT = {}
HAS_RECV = {}
HAS_ALLGATHER = {}
Expand Down Expand Up @@ -726,7 +718,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
assert tensor_list, "The result of parsing allgather op should not be None."

elif isinstance(op_desc, SendOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]:
Expand All @@ -735,7 +726,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
HAS_SENT[var_name].append(op_desc.dst)

elif isinstance(op_desc, RecvOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_RECV.keys():
HAS_RECV[var_name] = {}
if op_desc.src not in HAS_RECV[var_name].keys():
Expand Down
2 changes: 0 additions & 2 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_run_random_port)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_async)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_cloud)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_ascend)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_rank_mapping)
list(APPEND MIXED_DIST_TEST_OPS test_ascend_group)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_nproc)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input)
Expand Down Expand Up @@ -669,7 +668,6 @@ if(WITH_DISTRIBUTE)
bash_test_modules(test_fleet_launch_async START_BASH test_fleet_launch_async.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_cloud START_BASH test_fleet_launch_cloud.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_nproc START_BASH test_fleet_launch_nproc.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_rank_mapping START_BASH test_fleet_launch_rank_mapping.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
if(WITH_ASCEND OR WITH_ASCEND_CL)
bash_test_modules(test_fleet_launch_ascend START_BASH test_fleet_launch_ascend.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_ascend_group START_BASH test_ascend_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
Expand Down
Loading

0 comments on commit 48faf63

Please sign in to comment.