Skip to content

Commit

Permalink
[MoE]Assign pos op (#40580)
Browse files Browse the repository at this point in the history
* # This is a combination of 10 commits.
# The first commit's message is:
add expert count op

add ut for expert_count

# This is the 2nd commit message:

update UT only for cuda

# This is the 3rd commit message:

fix for rocm

# This is the 4th commit message:

update ut

# This is the 5th commit message:

add moe module

# This is the 6th commit message:

add expert count op

add ut for expert_count

# This is the 7th commit message:

update UT only for cuda

# This is the 8th commit message:

update ut

# This is the 9th commit message:

add moe module

# This is the 10th commit message:

make expert count private

* add assign pos op

* fix upper num name

* add api _assign pos

* add ut for assign pos op

* update date

* fix for win

* update for test (timeout)

* fix ut

* update

* fix ut for number count

Co-authored-by: hlygit66666 <2570058140@qq.com>
  • Loading branch information
sljlp and liyagit21 authored Mar 24, 2022
1 parent 9d8cfc1 commit 305f32d
Show file tree
Hide file tree
Showing 9 changed files with 430 additions and 43 deletions.
80 changes: 80 additions & 0 deletions paddle/fluid/operators/assign_pos_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/assign_pos_op.h"

namespace paddle {
namespace operators {

class AssignPosOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("eff_num_len"), "Input", "eff_num_len",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto cum_count_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
auto X_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");

PADDLE_ENFORCE_EQ(cum_count_dtype, X_dtype,
platform::errors::InvalidArgument(
"The dtype of the cum_count and X should be same"));
PADDLE_ENFORCE_EQ(cum_count_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
return framework::OpKernelType(cum_count_dtype, ctx.device_context());
}
};

class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "numbers to scatter.");
AddInput("cum_count", "The cumulative sum count of numbers.");
AddInput("eff_num_len",
"The effective numbers of numbers should be scattered.");
AddOutput("Out", "Assemble numbers in the order of counters.");

AddComment(R"DOC(
assign_pos_op Operator.
Assign pos decides which tokens should be fetched belong to
specially counter orderingly.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp,
ops::AssignPosOpMaker);

REGISTER_OP_CPU_KERNEL(assign_pos, ops::AssignPosOpCPUKernel<int>,
ops::AssignPosOpCPUKernel<int64_t>);
94 changes: 94 additions & 0 deletions paddle/fluid/operators/assign_pos_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/assign_pos_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

DECLARE_bool(avoid_op_randomness);

namespace paddle {
namespace operators {

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}

template <typename T>
__global__ void AssignPos(T* cum_count, const T* numbers, T* out,
int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
int number_idx = numbers[i];
if (number_idx > -1) {
int p = platform::CudaAtomicAdd(cum_count + number_idx, -1);
out[p - 1] = i;
}
}
}

template <typename T>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// assign pos decides which tokens should be fetched belong to specially
// counter orderingly.
auto cum_count = context.Input<LoDTensor>(
"cum_count"); // (counter number) int32 | int64
auto numbers =
context.Input<LoDTensor>("X"); // (batch_size * seq_len, topk) int32
auto eff_num_len =
context.Input<LoDTensor>("eff_num_len"); // (sum(cum_count))
auto out = context.Output<LoDTensor>("Out"); // (cum_count) value ranges
// from 0 to batch_size *
// seq_len * topk
auto place = context.GetPlace();
auto numel = numbers->numel();
T* cum_data = const_cast<T*>(cum_count->data<T>());
auto cum_size = cum_count->numel();

framework::Tensor cpu_eff_num_len;
int64_t cpu_eff_num_len_data = 0;
if (platform::is_cpu_place(eff_num_len->place())) {
cpu_eff_num_len_data = eff_num_len->data<T>()[0];
} else {
framework::TensorCopySync(*eff_num_len, platform::CPUPlace(),
&cpu_eff_num_len);
cpu_eff_num_len_data = cpu_eff_num_len.data<T>()[0];
}
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data});
auto out_data = out->mutable_data<T>(out_dims, place);

const T* num_data = numbers->data<T>();

int blocks = NumBlocks(numel);
int threads = kNumCUDAThreads;

AssignPos<T><<<blocks, threads, 0, dev_ctx.stream()>>>(cum_data, num_data,
out_data, numel);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int64_t>);
35 changes: 35 additions & 0 deletions paddle/fluid/operators/assign_pos_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

template <typename T>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support assign pos op for cpu kernel now."));
}
};

} // namespace operators
} // namespace paddle
24 changes: 11 additions & 13 deletions paddle/fluid/operators/number_count_op.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -22,34 +22,32 @@ class NumberCountOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx",
"NumberCount");
OP_INOUT_CHECK(ctx->HasInput("numbers"), "Input", "numbers", "NumberCount");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count",
"NumberCount");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// the dtype of the gate_idx should be same as int64
auto gate_idx_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx");
// the dtype of the numbers should be same as int64
auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers");

PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64,
PADDLE_ENFORCE_EQ(number_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the gate_idx_dtype should be int64"));
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace());
"The dtype of the number_dtype should be int64"));
return framework::OpKernelType(number_dtype, ctx.GetPlace());
}
};

class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("gate_idx", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output expert count tensor.");
AddAttr<int>("upper_range", "(int), The number of experts.");
AddInput("numbers", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output number count tensor.");
AddAttr<int>("upper_range", "(int), The number of different numbers.");

AddComment(R"DOC(number_count Operator.count gate indices.)DOC");
AddComment(R"DOC(number_count Operator.count numbers.)DOC");
}
};

Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/operators/number_count_op.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,7 +38,7 @@ __global__ void initialize_zero_kernel(T* data, const int length) {
}

template <typename T>
__global__ void NumberCount(const T* gate_idx, T* number_count,
__global__ void NumberCount(const T* numbers, T* number_count,
int64_t batch_size, int upper_range) {
int res_tmp[PERTHREAD_EXPERTS] = {0};
int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
Expand All @@ -47,7 +47,7 @@ __global__ void NumberCount(const T* gate_idx, T* number_count,
expert_max = upper_range;
}
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
T idx = gate_idx[i];
T idx = numbers[i];
if (idx == -1) {
continue;
}
Expand Down Expand Up @@ -76,18 +76,18 @@ template <typename T>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto gate_idx = context.Input<LoDTensor>("gate_idx");
auto numbers = context.Input<LoDTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<LoDTensor>("Out");

int64_t batch_size = gate_idx->numel();
int64_t batch_size = numbers->numel();
auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();

framework::DDim out_dims = phi::make_ddim({upper_range});
auto out_data = number_count->mutable_data<T>(out_dims, place);
const T* gate_data = gate_idx->data<T>();
const T* gate_data = numbers->data<T>();

initialize_zero_kernel<
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/number_count_op.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 305f32d

Please sign in to comment.