-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add alltoall api #32507
Merged
Merged
add alltoall api #32507
Changes from 8 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
ebc13a5
add ut, test=develop
19c099c
add alltoall op, test=develop
c03878c
add alltoall op, test=develop
3da1a89
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0c72e10
fix typos, test=develop
fac8a43
fix ut, test=develop
f4b021a
fix ut, test=deveop
62f6853
fix ut, test=develop
9fa800d
Merge branch 'develop' into alltoall
a8efaff
fix doc, test=develop
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/collective/alltoall_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class AllToAllOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AllToAll"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AllToAll"); | ||
int ring_id = ctx->Attrs().Get<int>("ring_id"); | ||
PADDLE_ENFORCE_GE( | ||
ring_id, 0, | ||
platform::errors::InvalidArgument( | ||
"The ring_id (%d) for alltoall op must be non-negative.", ring_id)); | ||
framework::DDim dim = ctx->GetInputDim("X"); | ||
if (dim[0] < 0) dim[0] = -1; | ||
ctx->SetOutputDim("Out", dim); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class AllToAllOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() { | ||
AddInput("X", "(Tensor) tensor send."); | ||
AddOutput("Out", "(Tensor) the result of alltoall."); | ||
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.") | ||
.SetDefault(0); | ||
AddAttr<bool>( | ||
"use_calc_stream", | ||
"(bool default false) eject CUDA operations to calculation stream.") | ||
.SetDefault(false); | ||
AddComment(R"DOC( | ||
AllToAll Operator | ||
Scatter tensors from all participators to all participators. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class AllToAllOpGradMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> retv) const override { | ||
retv->SetType("alltoall"); | ||
retv->SetInput("X", this->OutputGrad("Out")); | ||
retv->SetOutput("Out", this->InputGrad("X")); | ||
retv->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
DECLARE_INPLACE_OP_INFERER(AllToAllInplaceInferer, {"X", "Out"}); | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
REGISTER_OPERATOR(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker, | ||
ops::AllToAllOpGradMaker<paddle::framework::OpDesc>, | ||
ops::AllToAllOpGradMaker<paddle::imperative::OpBase>, | ||
ops::AllToAllInplaceInferer) | ||
|
||
REGISTER_OP_CPU_KERNEL(alltoall, ops::AllToAllOpCPUKernel<float>, | ||
ops::AllToAllOpCPUKernel<double>, | ||
ops::AllToAllOpCPUKernel<int>, | ||
ops::AllToAllOpCPUKernel<int64_t>, | ||
ops::AllToAllOpCPUKernel<plat::float16>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/collective/alltoall_op.h" | ||
|
||
#if defined(PADDLE_WITH_NCCL) | ||
#include "paddle/fluid/platform/collective_helper.h" | ||
#include "paddle/fluid/platform/nccl_helper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class AllToAllOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
#if defined(PADDLE_WITH_NCCL) | ||
#if NCCL_VERSION_CODE >= 2703 | ||
auto x = ctx.Input<framework::LoDTensor>("X"); | ||
auto out = ctx.Output<framework::LoDTensor>("Out"); | ||
int send_numel = x->numel(); | ||
ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); | ||
|
||
int ring_id = ctx.Attr<int>("ring_id"); | ||
PADDLE_ENFORCE_GE( | ||
ring_id, 0, | ||
platform::errors::InvalidArgument( | ||
"The ring_id (%d) for alltoall op must be non-negative.", ring_id)); | ||
auto place = ctx.GetPlace(); | ||
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); | ||
int nranks = comm->nranks(); | ||
|
||
cudaStream_t stream = nullptr; | ||
if (ctx.Attr<bool>("use_calc_stream")) { | ||
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); | ||
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream(); | ||
} else { | ||
stream = comm->stream(); | ||
} | ||
|
||
framework::DDim x_dims = x->dims(); | ||
framework::DDim out_dims(x_dims); | ||
PADDLE_ENFORCE_EQ( | ||
x_dims[0] % nranks, 0, | ||
platform::errors::InvalidArgument( | ||
"The first dimension size (%d) of the input tensor must be " | ||
"divisible by the number of ranks (%d).", | ||
x_dims[0], nranks)); | ||
auto send_buf = x->data<T>(); | ||
auto recv_buf = out->mutable_data<T>(out_dims, place); | ||
size_t offset = 0; | ||
send_numel /= nranks; | ||
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); | ||
for (auto i = 0; i < nranks; ++i) { | ||
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( | ||
send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); | ||
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( | ||
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); | ||
offset += send_numel; | ||
} | ||
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); | ||
#else | ||
PADDLE_THROW( | ||
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); | ||
#endif | ||
#else | ||
PADDLE_THROW( | ||
platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); | ||
#endif | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>, | ||
ops::AllToAllOpCUDAKernel<double>, | ||
ops::AllToAllOpCUDAKernel<int>, | ||
ops::AllToAllOpCUDAKernel<int64_t>, | ||
ops::AllToAllOpCUDAKernel<plat::float16>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
#if defined(PADDLE_WITH_GLOO) | ||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class AllToAllOpCPUKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_THROW(platform::errors::Unavailable( | ||
"Do not support alltoall for cpu kernel now.")); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
'scatter', | ||
'barrier', | ||
'split', | ||
'alltoall', | ||
'ReduceOp', | ||
] | ||
|
||
|
@@ -954,3 +955,73 @@ def split(x, | |
inner_rank, | ||
name=name) | ||
return linear_out | ||
|
||
|
||
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): | ||
""" | ||
Scatter tensors in in_tensor_list to all participators and gather the result tensors in out_tensor_list. | ||
Args: | ||
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type | ||
should be float16, float32, float64, int32 or int64. | ||
out_tensor_list (Tensor): A list of output Tensors. The data type of its elements should be the same as the | ||
data type of the input Tensors. | ||
group (Group): The group instance return by new_group or None for global default group. | ||
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False). | ||
Default to True. | ||
Returns: | ||
None. | ||
Examples: | ||
.. code-block:: python | ||
# required: distributed | ||
import numpy as np | ||
import paddle | ||
from paddle.distributed import init_parallel_env | ||
init_parallel_env() | ||
out_tensor_list = [] | ||
if paddle.distributed.ParallelEnv().rank == 0: | ||
np_data1 = np.array([[1, 2, 3], [4, 5, 6]]) | ||
np_data2 = np.array([[7, 8, 9], [10, 11, 12]]) | ||
else: | ||
np_data1 = np.array([[13, 14, 15], [16, 17, 18]]) | ||
np_data2 = np.array([[19, 20, 21], [22, 23, 24]]) | ||
data1 = paddle.to_tensor(np_data1) | ||
data2 = paddle.to_tensor(np_data2) | ||
paddle.distributed.all_to_all([data1, data2], out_tensor_list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以把跑完后的结果也放到文档里 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
""" | ||
if group is not None and not group.is_member(): | ||
return | ||
|
||
ring_id = 0 if group is None else group.id | ||
op_type = 'alltoall' | ||
temp = paddle.concat(in_tensor_list, axis=0) | ||
helper = LayerHelper(op_type, **locals()) | ||
nranks = len(in_tensor_list) | ||
out = helper.create_variable_for_type_inference( | ||
dtype=in_tensor_list[0].dtype) | ||
if in_dygraph_mode(): | ||
core.ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id', | ||
ring_id) | ||
else: | ||
if not isinstance(in_tensor_list, list): | ||
raise ValueError("The type of 'in_tensor_list' for all_to_all " | ||
"should be list.") | ||
for elem in in_tensor_list: | ||
check_variable_and_dtype( | ||
elem, 'in_tensor_list', | ||
['float16', 'float32', 'float64', 'int32', 'int64'], | ||
'all_to_all') | ||
if not isinstance(out_tensor_list, list): | ||
raise ValueError("The type of 'out_tensor_list' for all_to_all " | ||
"should be list.") | ||
if len(out_tensor_list) != 0: | ||
raise ValueError("The 'out_tensor_list' for all_to_all " | ||
"must be an empty list.") | ||
helper.append_op( | ||
type=op_type, | ||
inputs={'X': [temp]}, | ||
outputs={'Out': [out]}, | ||
attrs={ | ||
'ring_id': group, | ||
'use_calc_stream': use_calc_stream, | ||
}) | ||
out_tensor_list.extend(paddle.split(out, nranks, 0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
python/paddle/fluid/tests/unittests/collective_alltoall_api.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import argparse | ||
import os | ||
import sys | ||
import signal | ||
import time | ||
import socket | ||
from contextlib import closing | ||
from six import string_types | ||
import math | ||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.profiler as profiler | ||
import paddle.fluid.unique_name as nameGen | ||
from paddle.fluid import core | ||
import unittest | ||
from multiprocessing import Process | ||
import paddle.fluid.layers as layers | ||
from functools import reduce | ||
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main | ||
|
||
paddle.enable_static() | ||
|
||
|
||
class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase): | ||
def __init__(self): | ||
self.global_ring_id = 0 | ||
|
||
def get_model(self, main_prog, startup_program, rank): | ||
with fluid.program_guard(main_prog, startup_program): | ||
tindata = layers.data( | ||
name="tindata", shape=[10, 1000], dtype='float32') | ||
tindata = paddle.split(tindata, 2, axis=0) | ||
tout_data = [] | ||
paddle.distributed.alltoall(tindata, tout_data) | ||
return tout_data | ||
|
||
|
||
if __name__ == "__main__": | ||
runtime_main(TestCollectiveAllToAllAPI, "alltoall") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
group (Group, optional): The group instance return by new_group or None for global default group.Default: None.
下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.