Commit
move all_reduce
LiYuRio committed Oct 8, 2022
1 parent a60bbc5 commit 6fdd8ff
Showing 9 changed files with 197 additions and 131 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.cc
@@ -293,6 +293,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts) {
return AllReduce(inputs, outputs, opts, true);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts,
bool sync_op) {
auto tag = next_tag();
std::shared_ptr<GlooTask> task;
auto context = get_context();
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.h
@@ -120,6 +120,12 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts = AllreduceOptions()) override;

std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts,
bool sync_op) override;

std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;

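The overload added above threads a `sync_op` flag through `ProcessGroupGloo::AllReduce`. On the Python side this flag decides whether the caller blocks on the task returned by the process group. Below is a minimal sketch of that flow, mirroring the dygraph branch removed from `collective.py` later in this diff; it is not part of the commit, and `group` is assumed to be an already-created communication group.

```python
# Sketch only: how sync_op is expected to drive the process-group all-reduce
# in dygraph mode (mirrors the removed branch in collective.py).
from paddle.distributed.communication.reduce import ReduceOp, _get_reduce_op


def _all_reduce_dygraph(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    op_type = _get_reduce_op(op, "all_reduce")   # ReduceOp.* -> core.ReduceOp.*
    task = group.process_group.allreduce(tensor, op_type)
    if sync_op:
        task.wait()      # synchronous: block until the collective finishes
        return None
    return task          # asynchronous: the caller waits on the task later
```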
119 changes: 6 additions & 113 deletions python/paddle/distributed/collective.py
@@ -52,8 +52,9 @@
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
from .communication.comm_utils import ReduceOp
from .communication.group import Group
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce
from .communication.reduce import _get_reduce_op, ReduceOp

__all__ = []

@@ -153,19 +154,6 @@ def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _get_reduce_op(reduce_op, func_name):
if reduce_op == ReduceOp.SUM:
return core.ReduceOp.SUM
elif reduce_op == ReduceOp.MAX:
return core.ReduceOp.MAX
elif reduce_op == ReduceOp.MIN:
return core.ReduceOp.MIN
elif reduce_op == ReduceOp.PROD:
return core.ReduceOp.PRODUCT
else:
raise ValueError("Unknown reduce_op type for {}.".format(func_name))


def get_group(id=0):
"""
@@ -411,6 +399,9 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
_group_map_by_name[group_name] = group
_group_map[gid] = group
_group_map_backend[group] = backend
# TODO: The call below is the new group-management mechanism; it will replace the
# three maps above in the future.
_add_new_group(group)

# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by tcp
@@ -704,104 +695,6 @@ def broadcast(tensor, src, group=None, sync_op=True):
})


def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
"""
Reduce a tensor over all ranks so that all get the result.
As shown below, one process is started per GPU, and the data of each process is represented
by its group rank. The reduce operator is sum. Through the all_reduce operator,
each GPU ends up with the sum of the data from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
:width: 800
:alt: all_reduce
:align: center
Args:
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance returned by new_group or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. Default value is True.
Returns:
None.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_reduce(data)
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
if group is not None and not group.is_member():
return

if in_dygraph_mode():
op_type = _get_reduce_op(op, "all_reduce")
group = _get_default_group() if group is None else group
task = group.process_group.allreduce(tensor, op_type)
if sync_op:
task.wait()
return None
else:
return task

use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
if _non_static_mode():
if op == ReduceOp.SUM:
return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.MAX:
return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.MIN:
return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.PROD:
return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
else:
raise ValueError("Unknown parameter: {}.".format(op))

check_variable_and_dtype(tensor, 'tensor', [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
], 'all_reduce')
if op == ReduceOp.SUM:
op_type = 'c_allreduce_sum'
elif op == ReduceOp.MAX:
op_type = 'c_allreduce_max'
elif op == ReduceOp.MIN:
op_type = 'c_allreduce_min'
elif op == ReduceOp.PROD:
op_type = 'c_allreduce_prod'
if not isinstance(ring_id, int):
raise ValueError("The type of 'ring_id' for all_reduce should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream
})


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
"""
87 changes: 87 additions & 0 deletions python/paddle/distributed/communication/all_reduce.py
@@ -0,0 +1,87 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid.framework as framework
from paddle.distributed.communication import stream as stream
from paddle.distributed.communication.reduce import ReduceOp


def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
"""
Reduce a tensor over all ranks so that all get the result.
As shown below, one process is started per GPU, and the data of each process is represented
by its group rank. The reduce operator is sum. Through the all_reduce operator,
each GPU ends up with the sum of the data from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
:width: 800
:alt: all_reduce
:align: center
Args:
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance returned by new_group or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. Default value is True.
Returns:
Return a task object.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_reduce(data)
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.all_reduce(tensor,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False)

# code below will be removed after we remove the old dygraph
use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
if op == ReduceOp.SUM:
return paddle._legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.MAX:
return paddle._legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.MIN:
return paddle._legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
use_calc_stream, 'ring_id',
ring_id)
elif op == ReduceOp.PROD:
return paddle._legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
use_calc_stream,
'ring_id', ring_id)
else:
raise ValueError("Unknown parameter: {}.".format(op))
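For reference, a hedged usage sketch of the relocated `all_reduce` exercising the asynchronous path (`sync_op=False`). It is not part of this diff and assumes a two-GPU launch via `paddle.distributed.launch`.

```python
# Usage sketch (assumes: python -m paddle.distributed.launch --gpus=0,1 demo.py)
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# sync_op=False returns the task instead of waiting on it internally.
task = dist.all_reduce(data, op=dist.ReduceOp.SUM, sync_op=False)
task.wait()
print(data)  # [[5, 7, 9], [5, 7, 9]] on both ranks (2 GPUs)
```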
26 changes: 22 additions & 4 deletions python/paddle/distributed/communication/group.py
@@ -18,17 +18,17 @@ class Group():
The abstract representation of group.
"""

def __init__(self, group_rank, id, ranks, pg=None, name=None):
self._group_rank = group_rank
self._world_size = len(ranks) if group_rank >= 0 else -1
def __init__(self, rank_in_group, id, ranks, pg=None, name=None):
self._rank_in_group = rank_in_group
self._world_size = len(ranks) if rank_in_group >= 0 else -1
self._id = id
self._ranks = ranks
self._pg = pg
self._name = name

@property
def rank(self):
return self._group_rank
return self._rank_in_group

@property
def ranks(self):
@@ -74,3 +74,21 @@ def __repr__(self):
debug_str += "; name: "
debug_str += self.name if self.name else "None"
return debug_str


class _GroupManager():
global_group_id = 0
group_map_by_id = {}


def _get_global_group():
if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
raise RuntimeError("The global group is not initialized.")
return _GroupManager.group_map_by_id[_GroupManager.global_group_id]


def _add_new_group(group):
if group.id in _GroupManager.group_map_by_id:
raise RuntimeError("The group with id {} already exists.".format(
group.id))
_GroupManager.group_map_by_id[group.id] = group
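The `_GroupManager` registry above keys groups by id, with id 0 reserved for the global group. The following is an illustrative sketch of how the helpers fit together; it is internal API, not part of this diff, it assumes no group has been registered yet, and the group name `"_default_pg"` is hypothetical.

```python
# Illustrative only: how new_group() is expected to use the new registry.
from paddle.distributed.communication.group import (
    Group, _GroupManager, _add_new_group, _get_global_group)

g = Group(rank_in_group=0, id=0, ranks=[0, 1], pg=None, name="_default_pg")
_add_new_group(g)                              # registers the group by its id
assert g.id == _GroupManager.global_group_id   # id 0 marks the global group
assert _get_global_group() is g                # so lookup returns it
```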
python/paddle/distributed/communication/reduce.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid.framework as framework
import paddle.fluid.core as core


class ReduceOp:
"""
@@ -48,3 +51,26 @@ class ReduceOp:
MIN = 2
PROD = 3
AVG = 4


def _get_reduce_op(reduce_op, func_name):
if framework.in_dygraph_mode():
if reduce_op == ReduceOp.SUM:
return core.ReduceOp.SUM
elif reduce_op == ReduceOp.MAX:
return core.ReduceOp.MAX
elif reduce_op == ReduceOp.MIN:
return core.ReduceOp.MIN
elif reduce_op == ReduceOp.PROD:
return core.ReduceOp.PRODUCT
else:
if reduce_op == ReduceOp.SUM:
return 'c_allreduce_sum'
elif reduce_op == ReduceOp.MAX:
return 'c_allreduce_max'
elif reduce_op == ReduceOp.MIN:
return 'c_allreduce_min'
elif reduce_op == ReduceOp.PROD:
return 'c_allreduce_prod'

raise ValueError("Unknown reduce_op type for {}.".format(func_name))
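With this change, `_get_reduce_op` returns a different representation per execution mode: a `core.ReduceOp` value under dygraph, and a `c_allreduce_*` operator name for the static-graph path. A small hedged sketch of what callers receive, not part of this diff:

```python
# Sketch: what callers receive from _get_reduce_op in each mode.
import paddle.fluid.framework as framework
from paddle.distributed.communication.reduce import ReduceOp, _get_reduce_op

op = _get_reduce_op(ReduceOp.SUM, "all_reduce")
if framework.in_dygraph_mode():
    # dygraph: a core.ReduceOp value, passed to process_group.allreduce(...)
    print(op)
else:
    # static graph: the operator name appended via LayerHelper, e.g. 'c_allreduce_sum'
    print(op)
```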