diff --git a/paddle/fluid/operators/collective/alltoall_op.cc b/paddle/fluid/operators/collective/alltoall_op.cc new file mode 100644 index 0000000000000..1c57b9f996763 --- /dev/null +++ b/paddle/fluid/operators/collective/alltoall_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/alltoall_op.h" + +namespace paddle { +namespace operators { + +class AllToAllOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AllToAll"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AllToAll"); + int ring_id = ctx->Attrs().Get("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for alltoall op must be non-negative.", ring_id)); + framework::DDim dim = ctx->GetInputDim("X"); + if (dim[0] < 0) dim[0] = -1; + ctx->SetOutputDim("Out", dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class AllToAllOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor send."); + AddOutput("Out", "(Tensor) the result of alltoall."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +AllToAll Operator +Scatter tensors from all participators to all participators. +)DOC"); + } +}; + +template +class AllToAllOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("alltoall"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetOutput("Out", this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_INPLACE_OP_INFERER(AllToAllInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker, + ops::AllToAllOpGradMaker, + ops::AllToAllOpGradMaker, + ops::AllToAllInplaceInferer) + +REGISTER_OP_CPU_KERNEL(alltoall, ops::AllToAllOpCPUKernel, + ops::AllToAllOpCPUKernel, + ops::AllToAllOpCPUKernel, + ops::AllToAllOpCPUKernel, + ops::AllToAllOpCPUKernel); diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc new file mode 100644 index 0000000000000..1bcb47fc686cf --- /dev/null +++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc @@ -0,0 +1,95 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/alltoall_op.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class AllToAllOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) +#if NCCL_VERSION_CODE >= 2703 + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int send_numel = x->numel(); + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for alltoall op must be non-negative.", ring_id)); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + framework::DDim x_dims = x->dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], nranks)); + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel, + ops::AllToAllOpCUDAKernel, + ops::AllToAllOpCUDAKernel, + ops::AllToAllOpCUDAKernel, + ops::AllToAllOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/alltoall_op.h b/paddle/fluid/operators/collective/alltoall_op.h new file mode 100644 index 0000000000000..61eec44093794 --- /dev/null +++ b/paddle/fluid/operators/collective/alltoall_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class AllToAllOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support alltoall for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 7fb9e1d0455bb..fb74539410880 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -36,6 +36,7 @@ 'scatter', 'barrier', 'split', + 'alltoall', 'ReduceOp', 'send', 'recv', @@ -1174,6 +1175,77 @@ def split(x, return linear_out +def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): + """ + Scatter tensors in in_tensor_list to all participators and gather the result tensors in out_tensor_list. + Args: + in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32 or int64. + out_tensor_list (Tensor): A list of output Tensors. The data type of its elements should be the same as the + data type of the input Tensors. + group (Group, optional): The group instance return by new_group or None for global default group. Default: None. + use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. + Returns: + None. + Examples: + .. code-block:: python + # required: distributed + import numpy as np + import paddle + from paddle.distributed import init_parallel_env + init_parallel_env() + out_tensor_list = [] + if paddle.distributed.ParallelEnv().rank == 0: + np_data1 = np.array([[1, 2, 3], [4, 5, 6]]) + np_data2 = np.array([[7, 8, 9], [10, 11, 12]]) + else: + np_data1 = np.array([[13, 14, 15], [16, 17, 18]]) + np_data2 = np.array([[19, 20, 21], [22, 23, 24]]) + data1 = paddle.to_tensor(np_data1) + data2 = paddle.to_tensor(np_data2) + paddle.distributed.all_to_all([data1, data2], out_tensor_list) + # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] + # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] + """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + op_type = 'alltoall' + temp = paddle.concat(in_tensor_list, axis=0) + helper = LayerHelper(op_type, **locals()) + nranks = len(in_tensor_list) + out = helper.create_variable_for_type_inference( + dtype=in_tensor_list[0].dtype) + if in_dygraph_mode(): + core.ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id', + ring_id) + else: + if not isinstance(in_tensor_list, list): + raise ValueError("The type of 'in_tensor_list' for all_to_all " + "should be list.") + for elem in in_tensor_list: + check_variable_and_dtype( + elem, 'in_tensor_list', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'all_to_all') + if not isinstance(out_tensor_list, list): + raise ValueError("The type of 'out_tensor_list' for all_to_all " + "should be list.") + if len(out_tensor_list) != 0: + raise ValueError("The 'out_tensor_list' for all_to_all " + "must be an empty list.") + helper.append_op( + type=op_type, + inputs={'X': [temp]}, + outputs={'Out': [out]}, + attrs={ + 'ring_id': group, + 'use_calc_stream': use_calc_stream, + }) + out_tensor_list.extend(paddle.split(out, nranks, 0)) + + def send(tensor, dst=0, group=None, use_calc_stream=True): """ Send a tensor to the receiver. diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c1a29c050b138..8e998459cd499 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -96,6 +96,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_new_group_api) LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api) LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api) + LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall_api) LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api) LIST(REMOVE_ITEM TEST_OPS test_collective_wait) LIST(REMOVE_ITEM TEST_OPS test_memcpy_op) @@ -872,6 +873,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) endif() if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120) + set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120) @@ -907,6 +909,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_new_group_api test_collective_broadcast_api test_collective_allgather_api + test_collective_alltoall_api PROPERTIES LABELS "RUN_TYPE=DIST") endif() if(WITH_GPU OR WITH_ROCM) diff --git a/python/paddle/fluid/tests/unittests/collective_alltoall_api.py b/python/paddle/fluid/tests/unittests/collective_alltoall_api.py new file mode 100644 index 0000000000000..be18b68a1da33 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_alltoall_api.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank): + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + tindata = paddle.split(tindata, 2, axis=0) + tout_data = [] + paddle.distributed.alltoall(tindata, tout_data) + return tout_data + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllToAllAPI, "alltoall") diff --git a/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py b/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py new file mode 100644 index 0000000000000..fab975a9d6249 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_api_base import TestDistBase + +paddle.enable_static() + + +class TestCollectiveAllToAllAPI(TestDistBase): + def _setup_config(self): + pass + + def test_alltoall_nccl(self): + self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 832ffafa85e8c..e6693b676cf64 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -277,6 +277,19 @@ def check_with_place(self, self.assertTrue( np.allclose( result_data, need_result, rtol=1e-05, atol=1e-05)) + elif col_type == "alltoall": + need_result1 = np.vstack((input1[0:input1.shape[0] // 2, :], + input2[0:input2.shape[0] // 2, :])) + need_result2 = np.vstack((input1[input1.shape[0] // 2:, :], + input2[input2.shape[0] // 2:, :])) + tr0_out = np.vstack(tr0_out) + tr1_out = np.vstack(tr1_out) + self.assertTrue( + np.allclose( + tr0_out, need_result1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out, need_result2, rtol=1e-05, atol=1e-05)) elif col_type == "sendrecv": result_data = tr1_out[0] self.assertTrue(