From d4ca0d22c7db079364d317e87f557487c790e869 Mon Sep 17 00:00:00 2001 From: Wen Sun Date: Thu, 24 Nov 2022 18:10:42 +0800 Subject: [PATCH 1/2] refactor: move all_gather --- python/paddle/distributed/__init__.py | 5 +- python/paddle/distributed/collective.py | 226 ------------------ .../distributed/communication/__init__.py | 3 + .../distributed/communication/all_gather.py | 170 +++++++++++++ .../communication/stream/all_gather.py | 75 +++++- .../test_collective_allgather_api.py | 10 +- 6 files changed, 245 insertions(+), 244 deletions(-) create mode 100644 python/paddle/distributed/communication/all_gather.py diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 4db153c53b414..da60ef2eb6441 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -27,8 +27,6 @@ from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401 from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401 -from .collective import all_gather # noqa: F401 -from .collective import all_gather_object # noqa: F401 from .collective import barrier # noqa: F401 from .collective import split # noqa: F401 from .collective import new_group # noqa: F401 @@ -37,6 +35,8 @@ from .communication import ( stream, ReduceOp, + all_gather, + all_gather_object, all_reduce, alltoall, alltoall_single, @@ -112,4 +112,5 @@ "irecv", "reduce_scatter", "rpc", + "stream", ] diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 5f33748559848..ef4def05c239c 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -12,14 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import pickle -import io import datetime from ..fluid.layer_helper import LayerHelper from ..fluid.framework import in_dygraph_mode from ..fluid.framework import _non_static_mode -from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.layers.tensor import fill_constant import paddle import paddle.fluid.core as core @@ -435,225 +431,3 @@ def _sync_comm_stream(tensor, ring_id=0): outputs={'Out': [tensor]}, attrs={'ring_id': ring_id}, ) - - -def all_gather(tensor_list, tensor, group=None, sync_op=True): - """ - - Gather tensors from all participators and all get the result. As shown - below, one process is started with a GPU and the data of this process is represented - by its group rank. Through the all_gather operator, each GPU will have data - from all GPUs. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png - :width: 800 - :alt: all_gather - :align: center - - Args: - tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128. - tensor (Tensor): The Tensor to send. Its data type - should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128. - group (Group, optional): The group instance return by new_group or None for global default group. - sync_op (bool, optional): Whether this op is a sync op. The default value is True. - - Returns: - None. - - Examples: - .. code-block:: python - - # required: distributed - import paddle - import paddle.distributed as dist - - dist.init_parallel_env() - tensor_list = [] - if dist.get_rank() == 0: - data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) - else: - data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) - dist.all_gather(tensor_list, data) - print(tensor_list) - # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs) - """ - if group is not None and not group.is_member(): - return - - def convert_to_complex(list_of_tensor): - list_of_complex = [] - for tensor in list_of_tensor: - list_of_complex.append(paddle.as_complex(tensor)) - return list_of_complex - - is_input_complex = ( - tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128 - ) - if is_input_complex: - tensor = paddle.as_real(tensor) - - if in_dygraph_mode(): - group = _get_default_group() if group is None else group - if len(tensor_list) == 0: - tensor_shape = list(tensor.shape) - tensor_shape[0] *= group.nranks - out = paddle.empty(tensor_shape, tensor.dtype) - else: - out = paddle.concat(tensor_list, axis=0) - task = group.process_group.all_gather_into_tensor(out, tensor, sync_op) - task.wait() - tensor_list.clear() - list_of_tensor = paddle.split(out, group.nranks, 0) - if is_input_complex: - tensor_list.extend(convert_to_complex(list_of_tensor)) - else: - tensor_list.extend(list_of_tensor) - return - - use_calc_stream = sync_op - ring_id = 0 if group is None else group.id - nranks = _get_global_group().nranks if group is None else group.nranks - - if _non_static_mode(): - out = _legacy_C_ops.c_allgather( - tensor, - 'use_calc_stream', - use_calc_stream, - 'ring_id', - ring_id, - 'nranks', - nranks, - ) - else: - op_type = 'c_allgather' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - if not isinstance(tensor_list, list): - raise ValueError( - "The type of 'tensor_list' for all_gather " "should be list." - ) - for elem in tensor_list: - check_variable_and_dtype( - elem, - 'tensor_list', - [ - 'float16', - 'float32', - 'float64', - 'int32', - 'int64', - 'bool', - 'int8', - 'uint8', - 'complex64', - 'complex128', - ], - 'all_gather', - ) - check_variable_and_dtype( - tensor, - 'tensor', - [ - 'float16', - 'float32', - 'float64', - 'int32', - 'int64', - 'bool', - 'int8', - 'uint8', - 'complex64', - 'complex128', - ], - 'all_gather', - ) - helper.append_op( - type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [out]}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': use_calc_stream, - 'nranks': nranks, - }, - ) - - list_of_tensor = paddle.split(out, nranks, 0) - if is_input_complex: - tensor_list.extend(convert_to_complex(list_of_tensor)) - else: - tensor_list.extend(list_of_tensor) - - -def _convert_object_to_tensor(obj): - _pickler = pickle.Pickler - f = io.BytesIO() - _pickler(f).dump(obj) - data = np.frombuffer(f.getvalue(), dtype=np.uint8) - tensor = paddle.to_tensor(data) - return tensor, tensor.numel() - - -def _convert_tensor_to_object(tensor, len_of_tensor): - _unpickler = pickle.Unpickler - return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load() - - -def all_gather_object(object_list, obj, group=None): - """ - - Gather picklable objects from all participators and all get the result. Similiar to all_gather(), but python object can be passed in. - - Args: - object_list (list): A list of output object. The datatype of every element in the list is same as the input obj. - obj (Any): The picklable object to send. - group (Group): The group instance return by new_group or None for global default group. - - Returns: - None. - - Warning: - This API only supports the dygraph mode. - - Examples: - .. code-block:: python - - # required: distributed - import paddle - import paddle.distributed as dist - - dist.init_parallel_env() - object_list = [] - if dist.get_rank() == 0: - obj = {"foo": [1, 2, 3]} - else: - obj = {"bar": [4, 5, 6]} - dist.all_gather_object(object_list, obj) - print(object_list) - # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs) - """ - assert ( - in_dygraph_mode() - ), "all_gather_object doesn't support static graph mode." - - tensor, len_of_tensor = _convert_object_to_tensor(obj) - - # gather len_of_tensor from all ranks - list_len_of_tensor = [] - all_gather(list_len_of_tensor, len_of_tensor, group) - # get the max length from list - max_len_of_tensor = int(max(list_len_of_tensor).item()) - # resize the input tensor to max length avoid hang in all gather - # Note(liyurui): Maybe we should support various length all_gather? - # Now this operation is efficient for we don't support resize in python. - numpy_data = tensor.numpy() - numpy_data = np.resize(numpy_data, [max_len_of_tensor]) - input_tensor = paddle.to_tensor(numpy_data) - - tensor_list = [] - all_gather(tensor_list, input_tensor, group) - for i, tensor in enumerate(tensor_list): - object_list.append( - _convert_tensor_to_object(tensor, list_len_of_tensor[i]) - ) diff --git a/python/paddle/distributed/communication/__init__.py b/python/paddle/distributed/communication/__init__.py index bdd0f99371b85..3b5872ba2c8ff 100644 --- a/python/paddle/distributed/communication/__init__.py +++ b/python/paddle/distributed/communication/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .all_gather import all_gather, all_gather_object from .all_reduce import all_reduce from .broadcast import broadcast from .reduce import reduce, ReduceOp @@ -24,6 +25,8 @@ __all__ = [ "ReduceOp", + "all_gather", + "all_gather_object", "all_reduce", "alltoall", "alltoall_single", diff --git a/python/paddle/distributed/communication/all_gather.py b/python/paddle/distributed/communication/all_gather.py new file mode 100644 index 0000000000000..2a14a05a0b128 --- /dev/null +++ b/python/paddle/distributed/communication/all_gather.py @@ -0,0 +1,170 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import pickle + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.fluid.framework as framework +import paddle.distributed.communication.stream as stream + + +def all_gather(tensor_list, tensor, group=None, sync_op=True): + """ + + Gather tensors from all participators and all get the result. As shown + below, one process is started with a GPU and the data of this process is represented + by its group rank. Through the all_gather operator, each GPU will have data + from all GPUs. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png + :width: 800 + :alt: all_gather + :align: center + + Args: + tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. + tensor (Tensor): The Tensor to send. Its data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. + group (Group, optional): The group instance return by new_group or None for global default group. + sync_op (bool, optional): Whether this op is a sync op. The default value is True. + + Returns: + None. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + tensor_list = [] + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.all_gather(tensor_list, data) + print(tensor_list) + # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs) + """ + if not framework._in_legacy_dygraph(): + return stream.all_gather(tensor_list, tensor, group, sync_op) + + # NOTE: uncomment code below when having fully complex support + # def convert_to_complex(list_of_tensor): + # list_of_complex = [] + # for tensor in list_of_tensor: + # list_of_complex.append(paddle.as_complex(tensor)) + # return list_of_complex + + # is_input_complex = (tensor.dtype == paddle.complex64 + # or tensor.dtype == paddle.complex128) + # if is_input_complex: + # tensor = paddle.as_real(tensor) + + # code below will be removed after we remove the old dygraph + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + nranks = dist.get_world_size() + out = paddle._legacy_C_ops.c_allgather( + tensor, + 'use_calc_stream', + sync_op, + 'ring_id', + ring_id, + 'nranks', + nranks, + ) + tensor_list.clear() + tensor_list.extend(paddle.split(out, nranks, 0)) + + +def _convert_object_to_tensor(obj): + _pickler = pickle.Pickler + f = io.BytesIO() + _pickler(f).dump(obj) + data = np.frombuffer(f.getvalue(), dtype=np.uint8) + tensor = paddle.to_tensor(data) + return tensor, tensor.numel() + + +def _convert_tensor_to_object(tensor, len_of_tensor): + _unpickler = pickle.Unpickler + return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load() + + +def all_gather_object(object_list, obj, group=None): + """ + + Gather picklable objects from all participators and all get the result. Similiar to all_gather(), but python object can be passed in. + + Args: + object_list (list): A list of output object. The datatype of every element in the list is same as the input obj. + obj (Any): The picklable object to send. + group (Group): The group instance return by new_group or None for global default group. + + Returns: + None. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + object_list = [] + if dist.get_rank() == 0: + obj = {"foo": [1, 2, 3]} + else: + obj = {"bar": [4, 5, 6]} + dist.all_gather_object(object_list, obj) + print(object_list) + # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs) + """ + assert ( + framework.in_dygraph_mode() + ), "all_gather_object doesn't support static graph mode." + + tensor, len_of_tensor = _convert_object_to_tensor(obj) + + # gather len_of_tensor from all ranks + list_len_of_tensor = [] + all_gather(list_len_of_tensor, len_of_tensor, group) + # get the max length from list + max_len_of_tensor = int(max(list_len_of_tensor).item()) + # resize the input tensor to max length avoid hang in all gather + # Note(liyurui): Maybe we should support various length all_gather? + # Now this operation is efficient for we don't support resize in python. + numpy_data = tensor.numpy() + numpy_data = np.resize(numpy_data, [max_len_of_tensor]) + input_tensor = paddle.to_tensor(numpy_data) + + tensor_list = [] + all_gather(tensor_list, input_tensor, group) + for i, tensor in enumerate(tensor_list): + object_list.append( + _convert_tensor_to_object(tensor, list_len_of_tensor[i]) + ) diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py index 1e3344d0dbba0..641d2ea3be9a5 100644 --- a/python/paddle/distributed/communication/stream/all_gather.py +++ b/python/paddle/distributed/communication/stream/all_gather.py @@ -13,14 +13,17 @@ # limitations under the License. import paddle +import paddle.distributed as dist import paddle.fluid.framework as framework -from paddle.distributed import collective +import paddle.fluid.data_feeder as data_feeder +import paddle.fluid.layer_helper as layer_helper +from paddle.distributed.communication.group import _get_global_group def _all_gather_into_tensor_in_dygraph( out_tensor, in_tensor, group, sync_op, use_calc_stream ): - group = collective._get_default_group() if group is None else group + group = _get_global_group() if group is None else group if use_calc_stream: return group.process_group.all_gather_into_tensor_on_calc_stream( @@ -40,7 +43,7 @@ def _all_gather_into_tensor_in_dygraph( def _all_gather_in_dygraph( tensor_list, tensor, group, sync_op, use_calc_stream ): - group = collective._get_default_group() if group is None else group + group = _get_global_group() if group is None else group if len(tensor_list) == 0: tensor_list += [paddle.empty_like(tensor) for _ in range(group.nranks)] @@ -57,6 +60,58 @@ def _all_gather_in_dygraph( return task +def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op): + op_type = 'c_allgather' + helper = layer_helper.LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + for elem in tensor_list: + data_feeder.check_variable_and_dtype( + elem, + 'tensor_list', + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'bool', + 'int8', + 'uint8', + ], + 'all_gather', + ) + data_feeder.check_variable_and_dtype( + tensor, + 'tensor', + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'bool', + 'int8', + 'uint8', + ], + 'all_gather', + ) + + ring_id = 0 if group is None else group.id + nranks = dist.get_world_size() + helper.append_op( + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [out]}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': sync_op, + 'nranks': nranks, + }, + ) + tensor_list.clear() + tensor_list.extend(paddle.split(out, nranks, 0)) + + def all_gather( tensor_or_tensor_list, tensor, @@ -122,7 +177,13 @@ def all_gather( return _all_gather_in_dygraph( tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream ) - - raise RuntimeError( - "paddle.distributed.stream.all_gather is only supported in dygraph mode now." - ) + else: + assert group is None, "Group can not be used in static mode for now." + if paddle.is_tensor(tensor_or_tensor_list): + raise RuntimeError( + "Only support passing a tensor list to `all_gather` in static mode now." + ) + else: + return _all_gather_in_static_mode( + tensor_or_tensor_list, tensor, group, sync_op + ) diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py index f0e9b3bba2f60..004be6180a6ea 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py @@ -34,8 +34,6 @@ def test_allgather_nccl(self): "int8", "uint8", "bool", - "complex64", - "complex128", ] for dtype in dtypes_to_test: self.check_with_place( @@ -52,8 +50,6 @@ def test_allgather_gloo(self): "int8", "uint8", "bool", - "complex64", - "complex128", ] for dtype in dtypes_to_test: self.check_with_place( @@ -64,7 +60,7 @@ def test_allgather_gloo(self): dtype=dtype, ) - def test_allgatther_nccl_dygraph(self): + def test_allgather_nccl_dygraph(self): dtypes_to_test = [ "float16", "float32", @@ -74,8 +70,6 @@ def test_allgatther_nccl_dygraph(self): "int8", "uint8", "bool", - "complex64", - "complex128", ] if self._nccl_version >= 2100: dtypes_to_test.append("bfloat16") @@ -99,8 +93,6 @@ def test_allgather_gloo_dygraph(self): "uint8", "bool", "bfloat16", - "complex64", - "complex128", ] for dtype in dtypes_to_test: self.check_with_place( From a246a0d6ee8312a09332063273a2bec2613dc94e Mon Sep 17 00:00:00 2001 From: Wen Sun Date: Fri, 25 Nov 2022 10:10:25 +0800 Subject: [PATCH 2/2] chore: rm package in __all__ --- python/paddle/distributed/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index da60ef2eb6441..35d95c305778e 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -111,6 +111,4 @@ "isend", "irecv", "reduce_scatter", - "rpc", - "stream", ]