Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upload global scatter and global gather operators related files #35546

Merged
merged 3 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions paddle/fluid/operators/collective/global_gather_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/global_gather_op.h"

namespace paddle {
namespace operators {

class GlobalGatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GlobalGather");
OP_INOUT_CHECK(ctx->HasInput("local_count"), "Input", "local_count",
"GlobalGather");
OP_INOUT_CHECK(ctx->HasInput("global_count"), "Input", "global_count",
"GlobalGather");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GlobalGather");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for global gather op must be non-negative.",
ring_id));
auto input_dims = ctx->GetInputDim("X");
auto ndim_input = input_dims.size();
// dim check
PADDLE_ENFORCE_EQ(ndim_input, 2,
platform::errors::InvalidArgument(
"The input tensor's dimension must be 2. "
"But received input's dimension = %d.",
ndim_input));
framework::DDim out_dims = framework::make_ddim({-1, -1});
ctx->SetOutputDim("Out", out_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class GlobalGatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor send.");
AddInput("local_count",
"(Tensor) Tensor which has n_expert * world_size elements that "
"indicates"
"how many data needed to be received from each expert.");
AddInput("global_count",
"(Tensor) Tensor which has n_expert * world_size elements that "
"indicates"
"how many data needed to be sent to each expert.");
AddOutput("Out", "(Tensor) the result of global_gather.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
Global Gather Operator
Gather data in X to n_expert * world_size exeperts according to
local_count and receive tensors from n_expert * world_size experts according
to global_count.
)DOC");
}
};

template <typename T>
class GlobalGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("global_scatter");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetInput("local_count", this->Input("local_count"));
retv->SetInput("global_count", this->Input("global_count"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(global_gather, ops::GlobalGatherOp, ops::GlobalGatherOpMaker,
ops::GlobalGatherOpGradMaker<paddle::framework::OpDesc>,
ops::GlobalGatherOpGradMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(global_gather, ops::GlobalGatherOpCPUKernel<float>,
ops::GlobalGatherOpCPUKernel<double>,
ops::GlobalGatherOpCPUKernel<int>,
ops::GlobalGatherOpCPUKernel<int64_t>,
ops::GlobalGatherOpCPUKernel<plat::float16>);
146 changes: 146 additions & 0 deletions paddle/fluid/operators/collective/global_gather_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/global_gather_op.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {
template <typename T>
class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
auto local_count = ctx.Input<framework::LoDTensor>("local_count");
auto global_count = ctx.Input<framework::LoDTensor>("global_count");
auto local_count_type = local_count->type();
auto global_count_type = global_count->type();
if (local_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
}
if (global_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
}
auto out = ctx.Output<framework::LoDTensor>("Out");
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
auto local_count_len = 0;

framework::Tensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopySync(*local_count, platform::CPUPlace(),
&cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
youth123 marked this conversation as resolved.
Show resolved Hide resolved
local_count_len = cpu_local_count.numel();
}

framework::Tensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
} else {
framework::TensorCopySync(*global_count, platform::CPUPlace(),
&cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
youth123 marked this conversation as resolved.
Show resolved Hide resolved
}

ncclDataType_t dtype = platform::ToNCCLDataType(x->type());

int ring_id = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for global gather op must be non-negative.",
ring_id));
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int nranks = comm->nranks();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;

auto fwd_count = 0;

for (auto i = 0; i < local_count_len; ++i) {
fwd_count += cpu_local_count_data[i];
}
framework::DDim out_dims = framework::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
auto send_buf = x->data<T>();
auto recv_buf = out->mutable_data<T>(out_dims, place);

for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclSend(send_buf + send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
dtype, j, comm->comm(), stream));
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
dtype, j, comm->comm(), stream));
}
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
}
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(global_gather, ops::GlobalGatherOpCUDAKernel<float>,
ops::GlobalGatherOpCUDAKernel<double>,
ops::GlobalGatherOpCUDAKernel<int>,
ops::GlobalGatherOpCUDAKernel<int64_t>,
ops::GlobalGatherOpCUDAKernel<plat::float16>);
37 changes: 37 additions & 0 deletions paddle/fluid/operators/collective/global_gather_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support global gather op for cpu kernel now."));
}
};

} // namespace operators
} // namespace paddle
Loading