Move collective communication all_gather from collective.py #48339

Merged · 2 commits · Nov 25, 2022

Changes from 1 commit
5 changes: 3 additions & 2 deletions python/paddle/distributed/__init__.py
@@ -27,8 +27,6 @@
from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401
from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401

from .collective import all_gather # noqa: F401
from .collective import all_gather_object # noqa: F401
from .collective import barrier # noqa: F401
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
@@ -37,6 +35,8 @@
from .communication import (
stream,
ReduceOp,
all_gather,
all_gather_object,
all_reduce,
alltoall,
alltoall_single,
@@ -112,4 +112,5 @@
"irecv",
"reduce_scatter",
"rpc",
"stream",
]
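
After this change the public entry points are unchanged: paddle.distributed.all_gather and paddle.distributed.all_gather_object are now re-exported from the communication package instead of collective.py. Below is a minimal sanity-check sketch, assuming a PaddlePaddle build that contains this PR; the printed __module__ value is an expectation based on the "from .all_gather import ..." line further down, not something shown in the diff.

# Verify that the public names still resolve, and now come from the communication package.
import paddle.distributed as dist
import paddle.distributed.communication as comm

# paddle/distributed/__init__.py re-exports these names from .communication,
# so both attributes should refer to the same function objects.
assert dist.all_gather is comm.all_gather
assert dist.all_gather_object is comm.all_gather_object

print(dist.all_gather.__module__)  # expected: paddle.distributed.communication.all_gather
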
226 changes: 0 additions & 226 deletions python/paddle/distributed/collective.py
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pickle
import io
import datetime
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
@@ -435,225 +431,3 @@ def _sync_comm_stream(tensor, ring_id=0):
outputs={'Out': [tensor]},
attrs={'ring_id': ring_id},
)


def all_gather(tensor_list, tensor, group=None, sync_op=True):
"""

Gather tensors from all participators, and every participator gets the result. As shown
below, each process is started with one GPU, and the data of each process is represented
by its group rank. Through the all_gather operator, each GPU ends up with the data
from all GPUs.

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
:width: 800
:alt: all_gather
:align: center

Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group, optional): The group instance returned by new_group, or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.

Returns:
None.

Examples:
.. code-block:: python

# required: distributed
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
tensor_list = []
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_gather(tensor_list, data)
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if group is not None and not group.is_member():
return

def convert_to_complex(list_of_tensor):
list_of_complex = []
for tensor in list_of_tensor:
list_of_complex.append(paddle.as_complex(tensor))
return list_of_complex

is_input_complex = (
tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
)
if is_input_complex:
tensor = paddle.as_real(tensor)

if in_dygraph_mode():
group = _get_default_group() if group is None else group
if len(tensor_list) == 0:
tensor_shape = list(tensor.shape)
tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
task.wait()
tensor_list.clear()
list_of_tensor = paddle.split(out, group.nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
return

use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
nranks = _get_global_group().nranks if group is None else group.nranks

if _non_static_mode():
out = _legacy_C_ops.c_allgather(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'nranks',
nranks,
)
else:
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if not isinstance(tensor_list, list):
raise ValueError(
"The type of 'tensor_list' for all_gather " "should be list."
)
for elem in tensor_list:
check_variable_and_dtype(
elem,
'tensor_list',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
'complex64',
'complex128',
],
'all_gather',
)
check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
'complex64',
'complex128',
],
'all_gather',
)
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
'nranks': nranks,
},
)

list_of_tensor = paddle.split(out, nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor, tensor.numel()


def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
"""

Gather picklable objects from all participators, and every participator gets the result. Similar to all_gather(), but a Python object can be passed in.

Args:
object_list (list): A list of output objects. The data type of every element in the list is the same as that of the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance returned by new_group, or None for the global default group.

Returns:
None.

Warning:
This API only supports the dygraph mode.

Examples:
.. code-block:: python

# required: distributed
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
object_list = []
if dist.get_rank() == 0:
obj = {"foo": [1, 2, 3]}
else:
obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
print(object_list)
# [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
"""
assert (
in_dygraph_mode()
), "all_gather_object doesn't support static graph mode."

tensor, len_of_tensor = _convert_object_to_tensor(obj)

# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# resize the input tensor to the max length to avoid a hang in all_gather
# Note(liyurui): Maybe we should support various length all_gather?
# Now this operation is efficient for we don't support resize in python.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)

tensor_list = []
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i])
)
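
The deleted all_gather_object relies on a serialization trick: pickle the object into a uint8 buffer, resize every rank's buffer to the common maximum length so the fixed-size all_gather cannot hang, and unpickle only the meaningful prefix afterwards. Below is a standalone sketch of that round trip using numpy only, with no distributed runtime; the helper names are made up for illustration and only mirror the deleted private helpers.

import io
import pickle

import numpy as np


def object_to_buffer(obj):
    # Mirrors _convert_object_to_tensor: pickle the object into a uint8 byte buffer.
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    return data, data.size


def buffer_to_object(buf, length):
    # Mirrors _convert_tensor_to_object: only the first `length` bytes are meaningful.
    return pickle.Unpickler(io.BytesIO(buf[:length].tobytes())).load()


# Simulate two ranks holding differently sized payloads.
payloads = [{"foo": [1, 2, 3]}, {"bar": [4, 5, 6], "extra": "x" * 50}]
buffers, lengths = zip(*(object_to_buffer(p) for p in payloads))

# Resize every buffer to the maximum length, as the real code does before all_gather;
# the extra bytes are never read back, so their values do not matter.
max_len = max(lengths)
padded = [np.resize(b, [max_len]) for b in buffers]

# After a (simulated) all_gather, every rank would hold all padded buffers plus all lengths.
recovered = [buffer_to_object(b, n) for b, n in zip(padded, lengths)]
assert recovered == payloads
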
3 changes: 3 additions & 0 deletions python/paddle/distributed/communication/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .all_gather import all_gather, all_gather_object
from .all_reduce import all_reduce
from .broadcast import broadcast
from .reduce import reduce, ReduceOp
@@ -24,6 +25,8 @@

__all__ = [
"ReduceOp",
"all_gather",
"all_gather_object",
Contributor

Is this the officially exposed public calling form?
paddle.distributed.communication.all_gather()

Contributor

Please be sure to fix this in a follow-up PR: __all__ should only contain the publicly exposed APIs.

Contributor Author

@HermitSun Nov 25, 2022

Because the outer __init__ uses import *, this form is used here for now to avoid introducing extra dependencies.
It will be cleaned up in a follow-up PR.

"all_reduce",
"alltoall",
"alltoall_single",
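
On the reviewers' point about __all__: when a parent package does "from submodule import *", only the names listed in the submodule's __all__ survive the wildcard import, which is why all_gather had to be added there even though the communication subpackage is not meant to be the public entry point. Below is a toy, self-contained illustration; the module and function names are made up.

import sys
import types

# Build a fake submodule at runtime so the example runs anywhere.
comm = types.ModuleType("fake_communication")
exec(
    "def all_gather(): return 'gathered'\n"
    "def internal_helper(): return 'internal'\n"
    "__all__ = ['all_gather']\n",
    comm.__dict__,
)
sys.modules["fake_communication"] = comm

# Simulate a parent package doing `from .communication import *`.
parent_namespace = {}
exec("from fake_communication import *", parent_namespace)

assert "all_gather" in parent_namespace       # listed in __all__, so the wildcard import exposes it
assert "internal_helper" not in parent_namespace  # not listed, so it is not re-exported
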