Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add alltoall api #32507

Merged
merged 10 commits into from
Apr 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions paddle/fluid/operators/collective/alltoall_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/alltoall_op.h"

namespace paddle {
namespace operators {

// Operator definition for `alltoall`: validates the op's input/output and
// attributes, and derives the output shape from the input shape.
class AllToAllOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires input "X" and output "Out" to be present and `ring_id` to be
  // non-negative, then gives "Out" the same dims as "X". A negative leading
  // dimension is normalized to -1.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AllToAll");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AllToAll");

    int ring_id = ctx->Attrs().Get<int>("ring_id");
    PADDLE_ENFORCE_GE(
        ring_id, 0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for alltoall op must be non-negative.", ring_id));

    framework::DDim out_dims = ctx->GetInputDim("X");
    if (out_dims[0] < 0) {
      out_dims[0] = -1;
    }
    ctx->SetOutputDim("Out", out_dims);
  }

 protected:
  // The kernel is selected from the data type of "X" and the place of the
  // execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto x_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_data_type, ctx.GetPlace());
  }
};

// Declares the proto of the `alltoall` op: one input ("X"), one output
// ("Out") and the attributes `ring_id` and `use_calc_stream`.
class AllToAllOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Marked `override` so the compiler verifies that this really overrides the
  // base-class hook instead of silently declaring a new function.
  void Make() override {
    AddInput("X", "(Tensor) tensor send.");
    AddOutput("Out", "(Tensor) the result of alltoall.");
    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
        .SetDefault(0);
    AddAttr<bool>(
        "use_calc_stream",
        "(bool default false) eject CUDA operations to calculation stream.")
        .SetDefault(false);
    AddComment(R"DOC(
AllToAll Operator
Scatter tensors from all participators to all participators.
)DOC");
  }
};

// Builds the gradient op of `alltoall`. The gradient op is itself an
// `alltoall` op: it takes the gradient of "Out" as its "X", writes the
// gradient of "X" as its "Out", and reuses the forward op's attributes.
template <typename T>
class AllToAllOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("alltoall");
    grad_op->SetInput("X", this->OutputGrad("Out"));
    grad_op->SetOutput("Out", this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// Declares AllToAllInplaceInferer with the variable pair {"X", "Out"}; it is
// passed to REGISTER_OPERATOR for `alltoall` further down in this file.
// NOTE(review): presumably this allows "Out" to reuse the buffer of "X" --
// confirm against the DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(AllToAllInplaceInferer, {"X", "Out"});

} // namespace operators
} // namespace paddle

// Short aliases used by the registration macros below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the `alltoall` operator with its op class, proto maker, gradient
// makers (instantiated for framework::OpDesc and imperative::OpBase) and the
// in-place inferer.
REGISTER_OPERATOR(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker,
ops::AllToAllOpGradMaker<paddle::framework::OpDesc>,
ops::AllToAllOpGradMaker<paddle::imperative::OpBase>,
ops::AllToAllInplaceInferer)

// Registers the CPU kernel of `alltoall` for float, double, int, int64_t and
// float16.
REGISTER_OP_CPU_KERNEL(alltoall, ops::AllToAllOpCPUKernel<float>,
ops::AllToAllOpCPUKernel<double>,
ops::AllToAllOpCPUKernel<int>,
ops::AllToAllOpCPUKernel<int64_t>,
ops::AllToAllOpCPUKernel<plat::float16>);
95 changes: 95 additions & 0 deletions paddle/fluid/operators/collective/alltoall_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/alltoall_op.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

// CUDA kernel of the `alltoall` op.
//
// The input "X" is viewed as `nranks` equally sized contiguous chunks. Inside
// one NCCL group, chunk i is sent to rank i and chunk i of "Out" is received
// from rank i. Requires a build with NCCL whose version code is >= 2703;
// otherwise an Unavailable error is thrown.
template <typename T>
class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
#if NCCL_VERSION_CODE >= 2703
    auto x = ctx.Input<framework::LoDTensor>("X");
    auto out = ctx.Output<framework::LoDTensor>("Out");
    // Keep the element count 64-bit. Storing it in a plain `int` would
    // silently truncate the count for tensors with more than INT_MAX elements.
    const int64_t total_numel = x->numel();
    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());

    int ring_id = ctx.Attr<int>("ring_id");
    PADDLE_ENFORCE_GE(
        ring_id, 0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for alltoall op must be non-negative.", ring_id));
    auto place = ctx.GetPlace();
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    int nranks = comm->nranks();

    // Run on the calculation stream when requested, otherwise on the
    // communicator's own stream.
    cudaStream_t stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    framework::DDim x_dims = x->dims();
    framework::DDim out_dims(x_dims);
    PADDLE_ENFORCE_EQ(
        x_dims[0] % nranks, 0,
        platform::errors::InvalidArgument(
            "The first dimension size (%d) of the input tensor must be "
            "divisible by the number of ranks (%d).",
            x_dims[0], nranks));
    auto send_buf = x->data<T>();
    auto recv_buf = out->mutable_data<T>(out_dims, place);

    // Number of elements exchanged with each peer.
    const size_t chunk_numel = static_cast<size_t>(total_numel / nranks);
    size_t offset = 0;
    // All sends and receives are issued inside a single NCCL group.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
    for (int peer = 0; peer < nranks; ++peer) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
          send_buf + offset, chunk_numel, dtype, peer, comm->comm(), stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
          recv_buf + offset, chunk_numel, dtype, peer, comm->comm(), stream));
      offset += chunk_numel;
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
#else
    PADDLE_THROW(
        platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
    PADDLE_THROW(
        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
  }
};

} // namespace operators
} // namespace paddle

// Short aliases used by the registration macro below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA kernel of `alltoall` for float, double, int, int64_t and
// float16.
REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
ops::AllToAllOpCUDAKernel<double>,
ops::AllToAllOpCUDAKernel<int>,
ops::AllToAllOpCUDAKernel<int64_t>,
ops::AllToAllOpCUDAKernel<plat::float16>);
42 changes: 42 additions & 0 deletions paddle/fluid/operators/collective/alltoall_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

// CPU kernel placeholder for the `alltoall` op. There is no CPU
// implementation: calling Compute always throws an Unavailable error.
template <typename T>
class AllToAllOpCPUKernel : public framework::OpKernel<T> {
 public:
  // The execution context is intentionally unused.
  void Compute(const framework::ExecutionContext& /*ctx*/) const override {
    PADDLE_THROW(platform::errors::Unavailable(
        "Do not support alltoall for cpu kernel now."));
  }
};

} // namespace operators
} // namespace paddle
72 changes: 72 additions & 0 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'scatter',
'barrier',
'split',
'alltoall',
'ReduceOp',
'send',
'recv',
Expand Down Expand Up @@ -1174,6 +1175,77 @@ def split(x,
return linear_out


def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
    """
    Scatter tensors in in_tensor_list to all participators and gather the result tensors in out_tensor_list.

    Args:
        in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        out_tensor_list (list): A list that receives the output Tensors; the results are appended to it. The data
            type of its elements is the same as the data type of the input Tensors. In static graph mode it must
            be an empty list.
        group (Group, optional): The group instance returned by new_group or None for global default group. Default: None.
        use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: distributed
            import numpy as np
            import paddle
            from paddle.distributed import init_parallel_env

            init_parallel_env()
            out_tensor_list = []
            if paddle.distributed.ParallelEnv().rank == 0:
                np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
                np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
            else:
                np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
                np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
            data1 = paddle.to_tensor(np_data1)
            data2 = paddle.to_tensor(np_data2)
            paddle.distributed.alltoall([data1, data2], out_tensor_list)
            # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
            # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
    """
    # Processes that are not members of the given group take no part.
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    op_type = 'alltoall'
    # The op works on a single tensor: all inputs concatenated along axis 0.
    temp = paddle.concat(in_tensor_list, axis=0)
    helper = LayerHelper(op_type, **locals())
    nranks = len(in_tensor_list)
    if in_dygraph_mode():
        # Use the tensor returned by the op as the result; previously the
        # return value was dropped and an unrelated, empty variable was split.
        out = core.ops.alltoall_(temp, 'use_calc_stream', use_calc_stream,
                                 'ring_id', ring_id)
    else:
        if not isinstance(in_tensor_list, list):
            raise ValueError("The type of 'in_tensor_list' for all_to_all "
                             "should be list.")
        for elem in in_tensor_list:
            check_variable_and_dtype(
                elem, 'in_tensor_list',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                'all_to_all')
        if not isinstance(out_tensor_list, list):
            raise ValueError("The type of 'out_tensor_list' for all_to_all "
                             "should be list.")
        if len(out_tensor_list) != 0:
            raise ValueError("The 'out_tensor_list' for all_to_all "
                             "must be an empty list.")
        out = helper.create_variable_for_type_inference(
            dtype=in_tensor_list[0].dtype)
        helper.append_op(
            type=op_type,
            inputs={'X': [temp]},
            outputs={'Out': [out]},
            attrs={
                # The op attribute is the integer ring id, not the Group object.
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
            })
    # Split the result back into one tensor per participant.
    out_tensor_list.extend(paddle.split(out, nranks, 0))


def send(tensor, dst=0, group=None, use_calc_stream=True):
"""
Send a tensor to the receiver.
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
Expand Down Expand Up @@ -872,6 +873,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
Expand Down Expand Up @@ -907,6 +909,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_new_group_api
test_collective_broadcast_api
test_collective_allgather_api
test_collective_alltoall_api
PROPERTIES LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU OR WITH_ROCM)
Expand Down
Loading