Showing 8 changed files with 646 additions and 0 deletions.
@@ -0,0 +1,80 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/limit_by_capacity_op.h"

namespace paddle {
namespace operators {

class LimitByCapacityOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("expert_count"), "Input", "expert_count",
                   "LimitByCapacity");
    OP_INOUT_CHECK(ctx->HasInput("capacity"), "Input", "capacity",
                   "LimitByCapacity");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LimitByCapacity");

    ctx->ShareDim("expert_count", "Out");
    ctx->ShareLoD("expert_count", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // expert_count and capacity must both be int64 tensors.
    auto expert_count_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "expert_count");
    auto capacity_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "capacity");

    PADDLE_ENFORCE_EQ(
        expert_count_dtype, capacity_dtype,
        platform::errors::InvalidArgument(
            "The dtype of expert_count and capacity should be the same."));

    PADDLE_ENFORCE_EQ(
        expert_count_dtype, framework::proto::VarType::INT64,
        platform::errors::InvalidArgument(
            "The dtype of expert_count and capacity should be int64."));
    return framework::OpKernelType(expert_count_dtype, ctx.GetPlace());
  }
};

class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("expert_count", "(Tensor) The input expert count tensor.");
    AddInput("capacity", "(Tensor) The input capacity.");
    AddOutput("Out",
              "(Tensor) The output expert count, limited by capacity.");
    AddAttr<int>("n_worker", "(int) The number of workers.");
    AddComment(
        R"DOC(limit_by_capacity Operator. Limit the expert count by capacity.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity, ops::LimitByCapacityOp,
                             ops::LimitByCapacityOpMaker);

REGISTER_OP_CPU_KERNEL(limit_by_capacity, ops::LimitByCapacityOpCPUKernel<int>,
                       ops::LimitByCapacityOpCPUKernel<int64_t>);
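For context (not part of the commit): the operator limits each worker's per-expert count by the remaining capacity of that expert. The CPU kernel in the header further down only raises an error, so the actual computation lives in the CUDA kernel added in the next file. A rough sequential C++ sketch of the intended semantics, assuming the row-major [n_worker, n_expert] layout used by that kernel; the function name limit_by_capacity_reference is made up for illustration:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sequential reference for the limit_by_capacity semantics:
// expert_count is laid out as [n_worker, n_expert], capacity as [n_expert].
// Each worker's proposed count is granted from the remaining per-expert
// capacity and clamped to what is left.
std::vector<int64_t> limit_by_capacity_reference(
    const std::vector<int64_t>& expert_count, std::vector<int64_t> capacity,
    int64_t n_worker, int64_t n_expert) {
  std::vector<int64_t> out(expert_count.size(), 0);
  for (int64_t wid = 0; wid < n_worker; ++wid) {
    for (int64_t eid = 0; eid < n_expert; ++eid) {
      int64_t proposal = expert_count[wid * n_expert + eid];
      int64_t granted =
          std::min(std::max<int64_t>(capacity[eid], 0), proposal);
      capacity[eid] -= proposal;  // mirror the kernel's unconditional decrement
      out[wid * n_expert + eid] = granted;
    }
  }
  return out;
}

On the GPU the per-expert decrements happen atomically, so the order in which workers consume capacity is unspecified; the sequential loop above is only one possible ordering.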
@@ -0,0 +1,83 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/limit_by_capacity_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename T>
__global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out,
                                       const int n_expert,
                                       const int n_worker) {
  int eid = blockIdx.y;
  int wid = blockIdx.x * blockDim.x + threadIdx.x;
  if (wid < n_worker) {
    auto proposal = expc[wid * n_expert + eid];
    // Atomically reserve the proposed amount from expert `eid`'s remaining
    // capacity; cap_left is the capacity before this thread's subtraction.
    auto cap_left = paddle::platform::CudaAtomicAdd(cap + eid, proposal * (-1));
    if (cap_left >= proposal) {
      out[wid * n_expert + eid] = proposal;
    } else if (cap_left >= 0) {
      out[wid * n_expert + eid] = cap_left;
    } else {
      out[wid * n_expert + eid] = 0;
    }
  }
}

template <typename T>
class LimitByCapacityOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto expert_count = context.Input<Tensor>("expert_count");
    auto capacity = context.Input<Tensor>("capacity");
    auto n_worker = context.Attr<int>("n_worker");
    auto out = context.Output<Tensor>("Out");

    auto n_expert = expert_count->numel() / n_worker;
    const auto place = context.GetPlace();
    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
    dim3 block_dim(1024);
    auto out_data = out->mutable_data<T>(place);
    const T* ec_data = expert_count->data<T>();

    // The kernel decrements capacity in place, so work on a copy to keep the
    // input tensor unchanged.
    framework::Tensor capacity_copy;
    framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy);
    T* cap_data = capacity_copy.mutable_data<T>(place);

    limit_by_capacity_impl<T><<<grid_dim, block_dim, 0, dev_ctx.stream()>>>(
        ec_data, cap_data, out_data, n_expert, n_worker);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(limit_by_capacity,
                        ops::LimitByCapacityOpCUDAKernel<int64_t>);
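A side note, not part of the commit: the kernel's three-way branch on cap_left collapses to a single clamp when proposal is non-negative (which holds for counts), and CEIL(n_worker, 1024) is plain ceiling division, so grid.x covers the workers in blocks of 1024 threads while grid.y indexes the expert. A minimal standalone check of the clamp equivalence:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Sketch only: the kernel's if / else-if / else over cap_left equals clamping
// cap_left into [0, proposal] when proposal is non-negative.
int64_t grant(int64_t cap_left, int64_t proposal) {
  return std::min(std::max<int64_t>(cap_left, 0), proposal);
}

int main() {
  assert(grant(/*cap_left=*/10, /*proposal=*/4) == 4);  // enough capacity left
  assert(grant(/*cap_left=*/3, /*proposal=*/4) == 3);   // partially granted
  assert(grant(/*cap_left=*/-2, /*proposal=*/4) == 0);  // capacity exhausted
  return 0;
}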
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class LimitByCapacityOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_THROW(platform::errors::Unavailable(
        "limit_by_capacity is not supported on CPU yet."));
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,123 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"

namespace paddle {
namespace operators {

class PruneGateByCapacityOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("GateIdx"), "Input", "GateIdx",
                   "prune_gate_by_capacity");
    OP_INOUT_CHECK(ctx->HasInput("ExpertCount"), "Input", "ExpertCount",
                   "prune_gate_by_capacity");
    OP_INOUT_CHECK(ctx->HasOutput("NewGateIdx"), "Output", "NewGateIdx",
                   "prune_gate_by_capacity");

    auto expert_count_dims = ctx->GetInputDim("ExpertCount");

    int64_t n_expert = ctx->Attrs().Get<int64_t>("n_expert");
    int64_t n_worker = ctx->Attrs().Get<int64_t>("n_worker");

    // ExpertCount must hold exactly one count per (worker, expert) pair.
    int64_t expert_count_num_ele = 1;
    for (int64_t i = 0; i < expert_count_dims.size(); i++) {
      expert_count_num_ele *= expert_count_dims[i];
    }

    PADDLE_ENFORCE_EQ(
        expert_count_num_ele, n_expert * n_worker,
        platform::errors::Unavailable(
            "The number of elements of expert_count is ( %ld ), which is "
            "incorrect: it must equal the product of n_worker ( %ld ) and "
            "n_expert ( %ld ). Please pass an appropriate expert_count.",
            expert_count_num_ele, n_worker, n_expert));

    auto gate_idx_in_dims = ctx->GetInputDim("GateIdx");
    ctx->SetOutputDim("NewGateIdx", gate_idx_in_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto gate_idx_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx");
    auto expert_count_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "ExpertCount");
    PADDLE_ENFORCE_EQ(
        gate_idx_data_type, expert_count_data_type,
        platform::errors::InvalidArgument(
            "The dtype of gate_idx and expert_count should be the same."));
    PADDLE_ENFORCE_EQ(gate_idx_data_type, framework::proto::VarType::INT64,
                      platform::errors::InvalidArgument(
                          "The dtype of gate_idx and expert_count should "
                          "be int64."));
    return framework::OpKernelType(gate_idx_data_type, ctx.device_context());
  }
};

class PruneGateByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("GateIdx",
             "(Tensor) The gate_id sequence corresponding to the input data.");
    AddInput("ExpertCount",
             "(Tensor) The counts gathered over the gate_id sequence of the "
             "input data.");
    AddAttr<int64_t>("n_expert", "The number of experts on each worker.")
        .SetDefault(0);
    AddAttr<int64_t>("n_worker", "The number of workers on the trainer.")
        .SetDefault(0);

    AddOutput("NewGateIdx",
              "(Tensor) The gate_id sequence of the input data after pruning.");

    AddComment(R"DOC(
prune_gate_by_capacity Operator.
This operator prunes the gate ids by expert capacity (CUDA).
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity, ops::PruneGateByCapacityOp,
                             ops::PruneGateByCapacityOpMaker);

REGISTER_OP_CPU_KERNEL(
    prune_gate_by_capacity,
    ops::PruneGateByCapacityCPUKernel<paddle::platform::CPUDeviceContext, int>,
    ops::PruneGateByCapacityCPUKernel<paddle::platform::CPUDeviceContext,
                                      int64_t>);
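For orientation (illustration only, the helper name is made up and not part of the commit): the InferShape check above enforces numel(ExpertCount) == n_worker * n_expert, i.e. one count per (worker, expert) pair. Assuming the same row-major [n_worker, n_expert] layout as the limit_by_capacity kernel earlier in this commit, the count of expert eid on worker wid would be addressed as:

#include <cstdint>

// Hypothetical indexing helper: row-major [n_worker, n_expert] layout,
// so the count of expert `eid` on worker `wid` is at wid * n_expert + eid.
inline int64_t expert_count_index(int64_t wid, int64_t eid, int64_t n_expert) {
  return wid * n_expert + eid;
}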
🕵️ CI failures summary
🔍 PR: #33 Commit ID: 1965890 contains failed CI.
🔹 Failed: PR-CI-Windows-OPENBLAS (test_failed)