Skip to content

Commit

Permalink
add moe gate
Browse files Browse the repository at this point in the history
  • Loading branch information
liyagit21 authored and sljlp committed Jan 6, 2022
1 parent 9f0958f commit 1965890
Show file tree
Hide file tree
Showing 8 changed files with 646 additions and 0 deletions.
80 changes: 80 additions & 0 deletions paddle/fluid/operators/limit_by_capacity_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/limit_by_capacity_op.h"

namespace paddle {
namespace operators {

// Operator definition for limit_by_capacity: shape inference plus kernel
// dtype selection. The output mirrors the shape and LoD of expert_count.
class LimitByCapacityOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Both inputs and the single output must be declared on the op.
    OP_INOUT_CHECK(ctx->HasInput("expert_count"), "Input", "expert_count",
                   "LimitByCapacity");
    OP_INOUT_CHECK(ctx->HasInput("capacity"), "Input", "capacity",
                   "LimitByCapacity");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LimitByCapacity");

    // Out is element-wise over expert_count, so it inherits dims and LoD.
    ctx->ShareDim("expert_count", "Out");
    ctx->ShareLoD("expert_count", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // expert_count and capacity must share a dtype, and that dtype must be
    // int64; the kernel is selected on that dtype.
    const auto count_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "expert_count");
    const auto cap_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "capacity");

    PADDLE_ENFORCE_EQ(
        count_dtype, cap_dtype,
        platform::errors::InvalidArgument(
            "The dtype of the expert_count and capacity should be same"));

    PADDLE_ENFORCE_EQ(
        count_dtype, framework::proto::VarType::INT64,
        platform::errors::InvalidArgument("The dtype of the expert_count and "
                                          "capacity should be same as int64"));
    return framework::OpKernelType(count_dtype, ctx.GetPlace());
  }
};

// Declares the inputs, output, and attributes of limit_by_capacity for the
// op proto (names and doc strings shown to framework users).
class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("expert_count", "(Tensor) The input expert count tensor.");
    AddInput("capacity", "(Tensor) The input capacity.");
    AddOutput("Out",
              "(Tensor) The output tensor expert count limit by capacity.");
    // Fix: the attribute doc string said "works" instead of "workers".
    AddAttr<int>("n_worker", "(int), The number of workers.");
    AddComment(
        R"DOC(limit_by_capacity Operator.limit expert count by capacity.)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// NOTE(review): an <int> CPU kernel is registered, but GetExpectedKernelType
// enforces INT64 inputs, so the int instantiation appears unreachable — and
// the CPU kernel only throws Unavailable (see limit_by_capacity_op.h).
REGISTER_OP_CPU_KERNEL(limit_by_capacity, ops::LimitByCapacityOpCPUKernel<int>,
                       ops::LimitByCapacityOpCPUKernel<int64_t>);

// No gradient op: the macro registers the forward op and its proto maker only.
REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity, ops::LimitByCapacityOp,
                             ops::LimitByCapacityOpMaker);
83 changes: 83 additions & 0 deletions paddle/fluid/operators/limit_by_capacity_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/limit_by_capacity_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Ceiling integer division: the number of groups of size _y_ needed to cover
// _x_ items (used below to size the launch grid; assumes _x_ > 0).
#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// One thread per (worker, expert) pair. Each thread atomically draws its
// proposed count from the expert's remaining capacity and writes back how
// many tokens were actually granted: the full proposal if capacity allows,
// the leftover capacity if it is partially available, otherwise zero.
template <typename T>
__global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out,
                                       const int n_expert, const int n_worker) {
  const int eid = blockIdx.y;
  const int wid = blockIdx.x * blockDim.x + threadIdx.x;
  if (wid >= n_worker) {
    return;
  }
  const int idx = wid * n_expert + eid;
  const T proposal = expc[idx];
  // CudaAtomicAdd returns the capacity *before* this thread's subtraction.
  const T cap_before = paddle::platform::CudaAtomicAdd(cap + eid, -proposal);
  if (cap_before >= proposal) {
    out[idx] = proposal;
  } else if (cap_before >= 0) {
    out[idx] = cap_before;
  } else {
    out[idx] = 0;
  }
}

// CUDA kernel wrapper: copies `capacity` to scratch storage (the device
// kernel consumes it in place) and launches limit_by_capacity_impl with one
// grid row per expert and enough threads to cover every worker.
template <typename T>
class LimitByCapacityOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto expert_count = context.Input<Tensor>("expert_count");
    auto capacity = context.Input<Tensor>("capacity");
    auto out = context.Output<Tensor>("Out");
    auto n_worker = context.Attr<int>("n_worker");
    // expert_count is laid out as n_worker rows of n_expert entries.
    auto n_expert = expert_count->numel() / n_worker;

    const auto place = context.GetPlace();
    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    // Work on a copy of capacity: the kernel decrements it destructively.
    framework::Tensor capacity_copy;
    framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy);
    T* cap_data = capacity_copy.mutable_data<T>(place);

    const T* ec_data = expert_count->data<T>();
    auto out_data = out->mutable_data<T>(place);

    constexpr int kThreadsPerBlock = 1024;
    dim3 block_dim(kThreadsPerBlock);
    dim3 grid_dim(CEIL(n_worker, kThreadsPerBlock), n_expert);
    limit_by_capacity_impl<T><<<grid_dim, block_dim, 0, dev_ctx.stream()>>>(
        ec_data, cap_data, out_data, n_expert, n_worker);
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(limit_by_capacity,
ops::LimitByCapacityOpCUDAKernel<int64_t>);
37 changes: 37 additions & 0 deletions paddle/fluid/operators/limit_by_capacity_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

namespace paddle {
namespace operators {

// CPU placeholder kernel: limit_by_capacity is implemented only for GPU
// (see limit_by_capacity_op.cu), so any CPU invocation raises Unavailable.
template <typename T>
class LimitByCapacityOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_THROW(platform::errors::Unavailable(
        "Do not support limit by capacity op for cpu kernel now."));
  }
};

} // namespace operators
} // namespace paddle
123 changes: 123 additions & 0 deletions paddle/fluid/operators/prune_gate_by_capacity_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/prune_gate_by_capacity_op.h"

namespace paddle {
namespace operators {

// Operator definition for prune_gate_by_capacity: validates that ExpertCount
// holds exactly n_expert * n_worker elements, and sets NewGateIdx to the
// shape of GateIdx (the output is a per-element rewrite of the gate indices).
class PruneGateByCapacityOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Fix: the op name reported by these checks was misspelled
    // "prun_gate_by_capacity"; use the registered op name instead.
    OP_INOUT_CHECK(ctx->HasInput("GateIdx"), "Input", "GateIdx",
                   "prune_gate_by_capacity");
    OP_INOUT_CHECK(ctx->HasInput("ExpertCount"), "Input", "ExpertCount",
                   "prune_gate_by_capacity");
    OP_INOUT_CHECK(ctx->HasOutput("NewGateIdx"), "Output", "NewGateIdx",
                   "prune_gate_by_capacity");

    auto expert_count_dims = ctx->GetInputDim("ExpertCount");

    int64_t n_expert = ctx->Attrs().Get<int64_t>("n_expert");
    int64_t n_worker = ctx->Attrs().Get<int64_t>("n_worker");

    // ExpertCount must contain one counter per (worker, expert) pair.
    int64_t expert_count_num_ele = 1;
    for (int64_t i = 0; i < expert_count_dims.size(); i++) {
      expert_count_num_ele *= expert_count_dims[i];
    }

    PADDLE_ENFORCE_EQ(
        expert_count_num_ele, n_expert * n_worker,
        platform::errors::Unavailable(
            "The number of elements for expert_count is ( %ld ) incorrect. "
            "Because the number of expert_count must equal the "
            "product of n_worker ( %ld ) and n_expert ( %ld ). "
            "Please input appropriate expert_count again!",
            expert_count_num_ele, n_worker, n_expert));

    // NewGateIdx is element-wise over GateIdx, so shapes match.
    auto gate_idx_in_dims = ctx->GetInputDim("GateIdx");
    ctx->SetOutputDim("NewGateIdx", gate_idx_in_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // GateIdx and ExpertCount must share a dtype, and it must be int64;
    // the kernel is selected on that dtype.
    auto gate_idx_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx");
    auto expert_count_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "ExpertCount");
    PADDLE_ENFORCE_EQ(
        gate_idx_data_type, expert_count_data_type,
        platform::errors::InvalidArgument(
            "The dtype of the gate_idx and expert_count should be same"));
    PADDLE_ENFORCE_EQ(gate_idx_data_type, framework::proto::VarType::INT64,
                      platform::errors::InvalidArgument(
                          "The dtype of the gate_idx and expert_count should "
                          "be same as int64"));
    return framework::OpKernelType(gate_idx_data_type, ctx.device_context());
  }
};

// Declares the inputs, output, and attributes of prune_gate_by_capacity for
// the op proto. (Removed the dead, commented-out "ExpertCountOut" output.)
class PruneGateByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("GateIdx",
             "(Tensor), The gate_id sequence corresponding to the input data.");
    AddInput("ExpertCount",
             "(Tensor), The quantity value counted on the gate_id sequence of "
             "the input data.");
    AddAttr<int64_t>("n_expert", "The number of Experts on each worker")
        .SetDefault(0);
    AddAttr<int64_t>("n_worker", "The number of workers on the trainer")
        .SetDefault(0);

    AddOutput("NewGateIdx",
              "(Tensor), The gate_id sequence corresponding to the new input "
              "data after passing through prune.");

    AddComment(R"DOC(
prune_gate_by_capacity Operator.
This operator is used to prune gate by capacity(CUDA).
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// No gradient op: the macro registers the forward op and its proto maker only.
REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity, ops::PruneGateByCapacityOp,
                             ops::PruneGateByCapacityOpMaker);

// NOTE(review): an <int> CPU kernel is registered, but GetExpectedKernelType
// enforces INT64 inputs, so the int instantiation appears unreachable.
REGISTER_OP_CPU_KERNEL(
    prune_gate_by_capacity,
    ops::PruneGateByCapacityCPUKernel<paddle::platform::CPUDeviceContext, int>,
    ops::PruneGateByCapacityCPUKernel<paddle::platform::CPUDeviceContext,
                                      int64_t>);
Loading

1 comment on commit 1965890

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #33 Commit ID: 1965890 contains failed CI.

🔹 Failed: PR-CI-Windows-OPENBLAS

test_failed
2022-01-06 17:00:38 The following tests FAILED:
2022-01-06 17:00:38 920 - test_prune_gate_by_capacity_op (Failed)
2022-01-06 17:00:38 929 - test_prune_gate_by_capacity_op (Failed)
2022-01-06 17:00:38 929 - test_prune_gate_by_capacity_op (Failed)
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>goto:eof
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>set error_code=8
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>for /F %# in ('wmic os get localdatetime|findstr 20') do set end=%#
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>set end=20220106170037.148000+480
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>set end=0106170037
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>call :timestamp "0106163522" "0106170037" "1 card TestCases Total"
2022-01-06 17:00:38 C:\home\workspace\Paddle\build>setlocal enabledelayedexpansion
2022-01-06 17:00:38 578122
2022-01-06 17:00:38 "Windows 1 card TestCases Total Time: 1515s"
2022-01-06 17:00:38 ipipe_log_param_Windows_1_card_TestCases_Total_Time: 1515s
2022-01-06 17:00:38 578122
2022-01-06 17:00:38 "Windows TestCases Total Time: 1515s"
2022-01-06 17:00:38 ipipe_log_param_Windows_TestCases_Total_Time: 1515s
2022-01-06 17:00:38 Running unit tests failed, will exit
2022-01-06 17:00:38 EXCODE: 8

Please sign in to comment.