diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index fe82565bc36f3..b1676883ff39b 100755
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -23,6 +23,7 @@ register_operators(
   fused_transformer_op
   fused_feedforward_op
   fused_multi_transformer_op
+  fused_moe_op
   fused_multi_transformer_int8_op
   fused_bias_dropout_residual_layer_norm_op
   resnet_unit_op
@@ -120,6 +121,7 @@ if(WITH_GPU OR WITH_ROCM)
   # fused_attention_op
   op_library(fused_attention_op)
   op_library(fused_multi_transformer_op)
+  op_library(fused_moe_op)
   op_library(fused_multi_transformer_int8_op)
   op_library(fused_bias_dropout_residual_layer_norm_op)
 endif()
diff --git a/paddle/fluid/operators/fused/fused_moe_op.cc b/paddle/fluid/operators/fused/fused_moe_op.cc
new file mode 100644
index 0000000000000..faaaf5d5b1938
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_moe_op.cc
@@ -0,0 +1,132 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+class FusedMoeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    // input
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "fused_moe");
+    OP_INOUT_CHECK(context->HasInput("GateWeight"),
+                   "Input",
+                   "GateWeight",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInput("GateBias"),
+                   "Input",
+                   "GateBias",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInput("LnScale"),
+                   "Input",
+                   "LnScale",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInput("LnBias"),
+                   "Input",
+                   "LnBias",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInputs("ExpertsWeight1"),
+                   "Input",
+                   "ExpertsWeight1",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInputs("ExpertsBias1"),
+                   "Input",
+                   "ExpertsBias1",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInputs("ExpertsWeight2"),
+                   "Input",
+                   "ExpertsWeight2",
+                   "fused_moe");
+    OP_INOUT_CHECK(context->HasInputs("ExpertsBias2"),
+                   "Input",
+                   "ExpertsBias2",
+                   "fused_moe");
+    // output
+    OP_INOUT_CHECK(context->HasOutput("Out"),
+                   "Output",
+                   "Out",
+                   "fused_moe");
+    auto x_dims = context->GetInputDim("X");
+    context->SetOutputDim("Out", x_dims);
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class FusedMoeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    // AsDispensable: the input is optional and may be omitted
+    // AsDuplicable: the input is a duplicable (repeatable) tensor list
+    // input
+    AddInput("X", "The input of FusedMoe op");
+    AddInput("GateWeight", "The gate weight of FusedMoe op");
FusedMoe op"); + AddInput("GateBias", "The gate bias of FusedMoe op"); + AddInput("LnScale", "The ln scale of FusedMoe op"); + AddInput("LnBias", "The LnBias of FusedMoe op"); + AddInput("ExpertsWeight1", "The expert linear1 weights of fused_moe op") + .AsDuplicable(); + AddInput("ExpertsBias1", "The expert linear1 biases of fused_moe op") + .AsDuplicable() + .AsDispensable(); + AddInput("ExpertsWeight2", "The expert linear2 weights of fused_moe op") + .AsDuplicable(); + AddInput("ExpertsBias2", "The expert linear2 biases of fused_moe op") + .AsDuplicable() + .AsDispensable(); + // output + AddOutput("Out", "Out"); + // attr + AddAttr("pre_layer_norm", "pre_layer_norm").SetDefault(true); + AddAttr("ln_epsilon", "ln_epsilon").SetDefault(1e-5f); + AddAttr("topk", "top k in gate").SetDefault(2); + AddAttr("mp_size", "mp_size").SetDefault(1); + AddAttr("mp_rank", "mp_rank").SetDefault(0); + AddAttr("num_expert", "num_expert").SetDefault(1); + AddAttr("world_size", "world_size").SetDefault(1); + AddAttr("moe_ring_id", "moe_ring_id").SetDefault(-1); + AddComment(R"DOC( + The fused_moe operator is the same as the following pseudo codes: + + pass + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_moe, + ops::FusedMoeOp, + ops::FusedMoeOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_moe_op.cu b/paddle/fluid/operators/fused/fused_moe_op.cu new file mode 100644 index 0000000000000..0478c6cf551bc --- /dev/null +++ b/paddle/fluid/operators/fused/fused_moe_op.cu @@ -0,0 +1,501 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/fused_moe_op.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +static void AllToAll(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t send_numel = tensor.numel(); // send_numel + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + auto stream = ctx.stream(); + + framework::DDim x_dims = tensor.dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, + 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], + nranks)); + auto send_buf = tensor.data(); + auto recv_buf = out.mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +static void AllGather(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + auto out_dims = tensor.dims(); + int nranks = comm->nranks(); + out_dims[0] *= nranks; + out.mutable_data(out_dims, place); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +class FusedMoeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) 
const override { + using U = LayerNormParamType; + auto& dev_ctx = context.cuda_device_context(); + // input + auto* x = context.Input("X"); + auto* gate_weight = context.Input("GateWeight"); + auto* gate_bias = context.Input("GateBias"); + const bool pre_layer_norm = context.Attr("pre_layer_norm"); + auto* ln_scale = + pre_layer_norm ? context.Input("LnScale") : nullptr; + auto* ln_bias = + pre_layer_norm ? context.Input("LnBias") : nullptr; + // linear 1 + auto experts_weight1 = context.MultiInput("ExpertsWeight1"); + auto experts_bias1 = context.MultiInput("ExpertsBias1"); + // linear 2 + auto experts_weight2 = context.MultiInput("ExpertsWeight2"); + auto experts_bias2 = context.MultiInput("ExpertsBias2"); + + // output + auto* out = context.Output("Out"); + dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // attr + const float epsilon = context.Attr("ln_epsilon"); + const int topk = context.Attr("topk"); + const int mp_size = context.Attr("mp_size"); + const int mp_rank = context.Attr("mp_rank"); + const int num_expert = context.Attr("num_expert"); + const int world_size = context.Attr("world_size"); + const int moe_ring_id = context.Attr("moe_ring_id"); + + // dim + auto x_dim = x->dims(); + int bsz = x_dim[0]; + int seq_len = x_dim[1]; + int bsz_seq = bsz * seq_len; + int d_model = x_dim[2]; + int tot_expert = world_size * num_expert; + int dim_feedforward = experts_weight1[0]->dims()[1]; + + // pre_layer_norm + const U* ln_scale_ptr = + ln_scale == nullptr ? nullptr : ln_scale->data(); + const U* ln_bias_ptr = + ln_bias == nullptr ? nullptr : ln_bias->data(); + Tensor ln_mean, ln_variance; + ln_mean.Resize({{bsz_seq}}); + auto* ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_variance.Resize({{bsz_seq}}); + auto* ln_variance_data = + dev_ctx.Alloc(&ln_variance, ln_variance.numel() * sizeof(U)); + FusedDropoutLayerNormHelper pre_layernorm_helper( + bsz_seq, d_model, epsilon); + // tmp out + Tensor ln_out; + ln_out.Resize({{bsz, seq_len, d_model}}); + auto *ln_out_data = dev_ctx.Alloc(&ln_out, ln_out.numel() * sizeof(T)); + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, d_model}}); + auto* sliced_inp_data = dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + auto* gate_out_data = dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + auto gate_linear_compute = AttnMatMul( + dev_ctx, false, false, sliced_bsz_seq, tot_expert, d_model, true); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + auto* topk_value_data = dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + auto* topk_idx_data = dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(int64_t)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + auto* local_expert_count_data = + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + auto* global_expert_count_data = + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() 
* sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + fwd_expert_count.Resize({{world_size, num_expert}}); + fwd_batch_size.Resize({{1}}); + auto* fwd_expert_count_data = + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + auto* fwd_batch_size_data = + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + auto* pos_data = dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + auto* temp_pos_data = dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + auto* lec_cum_data = dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, d_model}}); + auto* index_select_out_data = dev_ctx.Alloc(&index_select_out, + index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, d_model}}); + auto* global_gather_out_data = dev_ctx.Alloc(&global_gather_out, + global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, d_model}}); + auto* moe_gather_out_data = dev_ctx.Alloc(&moe_gather_out, + moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, d_model}}); + auto* bmm_out_data = dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, d_model}}); + auto* all_gather_out_data = + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + + // step1 layer norm + if (pre_layer_norm) { + pre_layernorm_helper.LayerNorm(dev_ctx, + x->data(), + ln_scale_ptr, + ln_bias_ptr, + ln_out_data, + ln_mean_data, + ln_variance_data); + } else { + ln_out = *x; + } + // step2 resize and slice ln_out + ln_out.Resize({{bsz_seq, d_model}}); + if (mp_size > 1) { + sliced_inp = ln_out.Slice(start, end); + } else { + sliced_inp = ln_out; + } + // step3 gate & topk + gate_linear_compute.ComputeForward(gate_weight, &sliced_inp, gate_bias, &gate_out, &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + phi::Scalar(topk), + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count + NumberCountCompute(dev_ctx, &topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all + if (world_size > 1) { + AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos + phi::CumsumKernel(dev_ctx, + local_expert_count, + phi::Scalar(0), + false, + false, + false, + &lec_cum); + AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos); + if (topk > 1) { + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + auto *topk_tensor_data = dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), 
&topk_tensor); + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } + Tensor fwd_expert_count_cpu; + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + Tensor fwd_batch_size_cpu; + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, d_model}}); + auto* global_scatter_out_data = dev_ctx.Alloc(&global_scatter_out, + global_scatter_out.numel() * sizeof(T)); + std::vector tmp_expert_out; + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, d_model}}); + auto* all_expert_out_data = dev_ctx.Alloc(&all_expert_out, + all_expert_out.numel() * sizeof(T)); + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + GlobalScatterProcessGroupFunctor functor_; + functor_(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + GlobalScatterFunctor functor_; + functor_(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + // step 6, Expert Computation + if (global_scatter_out.dims()[0] != 0) { + int last_index = 0; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + Tensor expert_out1; + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + auto *expert_out1_data = dev_ctx.Alloc(&expert_out1, + expert_out1.numel() * sizeof(T)); + Tensor act_bias_out; + act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); + auto *act_bias_out_data = dev_ctx.Alloc(&act_bias_out, + act_bias_out.numel() * sizeof(T)); + Tensor expert_out2; + expert_out2.Resize({{cur_expert_count, d_model}}); + auto *expert_out2_data = dev_ctx.Alloc(&expert_out2, + expert_out2.numel() * sizeof(T)); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + // linear1 matmul + MatMulAndAdd(dev_ctx, + experts_weight1[idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // bias gelu + fused_act_dropout_helper.DropoutActBias(dev_ctx, + expert_out1.data(), + experts_bias1[idx]->data(), + "gelu", + act_bias_out.data(), + nullptr); + // linear2 matmul & add + MatMulAndAdd(dev_ctx, + experts_weight2[idx], + &act_bias_out, + experts_bias2[idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + } + phi::funcs::ConcatFunctor concat; + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + // step7. 
MOEGather + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + GlobalGatherProcessGroupFunctor functor_; + functor_(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + GlobalGatherFunctor functor_; + functor_(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + if (topk > 1) { + // moe gather out reshape + moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, d_model}}); + } else { + bmm_out = moe_gather_out; + } + // step 9, AllGather + if (mp_size > 1) { + // all gather + AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + // step 10, reshape + all_gather_out.Resize(x_dim); + // step 11, add residual + phi::AddKernel(dev_ctx, all_gather_out, *x, out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fused_moe, + ops::FusedMoeOpKernel, + ops::FusedMoeOpKernel, + ops::FusedMoeOpKernel); diff --git a/paddle/fluid/operators/fused/fused_moe_op.h b/paddle/fluid/operators/fused/fused_moe_op.h new file mode 100644 index 0000000000000..5ebfbff589850 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_moe_op.h @@ -0,0 +1,718 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/top_k_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/functors.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/fluid/operators/collective/global_scatter_op.h" +#include "paddle/fluid/operators/collective/global_gather_op.h" +#include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +// #include "paddle/fluid/framework/convert_utils.h" +// #include "paddle/fluid/platform/float16.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { +// number count +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) +#define PERTHREAD_EXPERTS 256 +#define WARP_SIZE 32 + +const int CUDA_NUM_THREADS = 512; +static inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void initialize_zero_kernel(T* data, const int length) { + CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast(0); } +} + +template +__global__ void NumberCount(const T* numbers, + T* number_count, + int64_t batch_size, + int upper_range) { + int res_tmp[PERTHREAD_EXPERTS] = {0}; + int expert_min = blockIdx.x * PERTHREAD_EXPERTS; + int expert_max = expert_min + PERTHREAD_EXPERTS; + if (expert_max > upper_range) { + expert_max = upper_range; + } + for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { + T idx = numbers[i]; + if (idx == -1) { + continue; + } + if (idx < expert_min || idx >= expert_max) { + continue; + } + res_tmp[idx - expert_min] += 1; + } + for (int i = expert_min; i < expert_max; ++i) { + int x = res_tmp[i - expert_min]; +#pragma unroll + for (int j = 1; j < WARP_SIZE; j <<= 1) { +#ifdef __HIPCC__ + x = x + __shfl_down(x, j); +#else + x = x + __shfl_down_sync(-1u, x, j); +#endif + } + if (threadIdx.x % WARP_SIZE == 0) { + platform::CudaAtomicAdd(number_count + i, x); + } + } +} + +template +void NumberCountCompute(const phi::GPUContext &dev_ctx, + framework::Tensor* numbers, + int upper_range, + framework::Tensor* out) { + int64_t batch_size = numbers->numel(); + auto place = dev_ctx.GetPlace(); + + framework::DDim out_dims = phi::make_ddim({upper_range}); + auto out_data = out->mutable_data(out_dims, place); + const T* gate_data = numbers->data(); + + initialize_zero_kernel + <<>>( + out_data, upper_range); + + NumberCount + <<>>( + gate_data, out_data, batch_size, upper_range); +} + +// assign pos +static constexpr int 
kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void AssignPos(T* cum_count, + const T* numbers, + T* out, + int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + int number_idx = numbers[i]; + if (number_idx > -1) { + int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = i; + } + } +} + +template +void AssignPosCompute(const phi::GPUContext &dev_ctx, + framework::Tensor* cum_count, // (counter number) int32 | int64 + framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 + framework::Tensor* out) { + auto place = dev_ctx.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + auto cum_size = cum_count->numel(); + + framework::Tensor cpu_cum_count; + int64_t cpu_eff_num_len_data = 0; + if (platform::is_cpu_place(cum_count->place())) { + cpu_eff_num_len_data = cum_count->data()[cum_size - 1]; + } else { + framework::TensorCopySync( + *cum_count, platform::CPUPlace(), &cpu_cum_count); + cpu_eff_num_len_data = cpu_cum_count.data()[cum_size - 1]; + } + + framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + AssignPos<<>>( + cum_data, num_data, out_data, numel); +} + +template +struct GlobalScatterFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global 
scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + // HARD CODE HERE! + // auto place = platform::CUDAPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + // stream = static_cast(dev_ctx)->stream(); + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalScatterProcessGroupFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + 
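+      // Make the per-expert counts visible on the host: they decide how many
+      // rows each Send_Partial/Recv_Partial below has to transfer.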
cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + // HARD CODE HERE! + // auto place = platform::CUDAPlace(); + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + distributed::ProcessGroup* pg = map->get(ring_id); + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial(tmp, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + if (cpu_global_count_data[idx]) { + pg->Recv_Partial(*out, + j, + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalGatherFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + 
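+      // local_count already lives on the host; its sum (fwd_count) later
+      // determines how many rows are gathered back onto this rank.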
} else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + // auto place = platform::CUDAPlace(); + + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + // stream = static_cast(dev_ctx)->stream(); + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + send_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalGatherProcessGroupFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in 
local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + // auto place = platform::CUDAPlace(); + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + distributed::ProcessGroup* pg = map->get(ring_id); + + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial( + tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + pg->Recv_Partial(*out, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +void MatMulAndAdd(const phi::GPUContext& dev_ctx, + const framework::Tensor* weight, + const framework::Tensor* input, + const framework::Tensor* bias, + bool istransA, + bool istransB, + bool compute_bias, + framework::Tensor* output, + framework::Tensor* bias_out) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // here: (transa, transb): nt, input * weight. + CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transB = istransB ? 
CblasTrans : CblasNoTrans; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(transA, + transB, + input->dims()[0], + weight->dims()[1], + input->dims()[1], + alpha, + input->data(), + weight->data(), + beta, + output->data()); + if (compute_bias) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + } +} + +} // namesapce operators +} //namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index af080bd0b3431..9f0ca6b62d3cb 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,6 +71,16 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_moe", + {"X", + "GateWeight", + "GateBias", + "LnScale", + "LnBias", + "ExpertsWeight1", + "ExpertsBias1", + "ExpertsWeight2", + "ExpertsBias2"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", @@ -335,6 +345,7 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_moe", {"Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index b2767b1dd1cbf..677bbc0c2b288 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -166,6 +166,7 @@ def _update_list(self): 'concat', 'split', 'fused_feedforward', + 'fused_moe', 'fused_attention', 'fused_multi_transformer', } diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index b23c94c7e4994..9b0f5fa1006aa 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -112,6 +112,8 @@ def _keep_fp32_input(op, in_name): } if op_type == 'fused_multi_transformer': return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} + if op_type == 'fused_moe': + return in_name in {'LnScale', 'LnBias'} return False diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index decaf45125750..98efe71421cf4 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -186,13 +186,13 @@ def pure_fp16_initialize(models): if (layer._dtype == 'float16') or isinstance( layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, - paddle.nn.ParameterList)): + paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, paddle.nn.ParameterList)): # tianyan01 add paddle.nn.ParameterList, hack continue if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedMultiHeadAttention, - paddle.incubate.nn.FusedMultiTransformer)): + paddle.incubate.nn.FusedMultiTransformer, + paddle.incubate.nn.FusedMoELayer)): layer._amp_decorate(dtype='float16') continue layer._to_impl(dtype='float16', diff --git 
a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index cf15ee7d8ffaa..2a2def22bb3bf 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -18,6 +18,7 @@ from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 +from .layer.fused_transformer import FusedMoELayer # tianyan01 add __all__ = [ #noqa 'FusedMultiHeadAttention', @@ -26,4 +27,5 @@ 'FusedMultiTransformer', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', + 'FusedMoELayer', ] diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index d71086fe7b07d..8c2053c4bd47c 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -16,6 +16,8 @@ from paddle.nn import Layer from paddle.framework import ParamAttr import paddle +import paddle.nn as nn +from paddle import _legacy_C_ops from paddle.nn import ParameterList from paddle.nn.layer.transformer import ( _convert_attention_mask, @@ -1454,4 +1456,170 @@ def trans_to_fp16(l): trans_to_fp16(self.ffn1_biases) trans_to_fp16(self.ffn2_weights) trans_to_fp16(self.ffn2_biases) + self._dtype = dtype + + +class FusedMoELayer(Layer): + """FusedMoE Layer + Args: + d_model: (int) model dimention + num_expert: (int) expert count + top_k: (int) top-k number + some weights and bias... + moe_group: moe group for experts communication + mp_group: mp group for mp commutication + Examples: + .. code-block:: python + # required: gpu + import paddle + from paddle.incubate.nn import FusedMoELayer + + # input: [batch_size, src_len, d_model] + input = paddle.rand((2, 4, 128)) + # dim_feedforward = 128 + fused_moe_layer = FusedMoELayer(128, 128, 4, 2) + output = fused_moe_layer(input) # [2, 4, 128] + + """ + + def __init__(self, + d_model, + dim_feedforward, + num_expert, + top_k, + ln_scale=None, + ln_bias=None, + gate_weight=None, + gate_bias=None, + linear1_weights=None, + linear1_biases=None, + linear2_weights=None, + linear2_biases=None, + moe_group=None, + mp_group=None): + super(FusedMoELayer, self).__init__() + # only support mp/dp + self.group = moe_group + + self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = num_expert + + self.mp_group = mp_group + self.mp_rank = 0 + self.mp_size = 1 + if mp_group is not None and mp_group.nranks > 1: + self.mp_rank = mp_group.rank + self.mp_size = mp_group.nranks + self.d_model = d_model + self.top_k = top_k + self.ln_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False + ) + self.ln_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True + ) + self.gate_weight = self.create_parameter( + shape=[d_model, num_expert * self.world_size], + attr=None, + dtype=self._dtype, + is_bias=False + ) + self.gate_bias = self.create_parameter( + shape=[num_expert * self.world_size], + attr=None, + dtype=self._dtype, + is_bias=True + ) + + self.linear1_weights = ParameterList() + self.linear2_weights = ParameterList() + self.linear1_biases = ParameterList() + self.linear2_biases = ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_expert + return attrs[idx] + return attrs + for i in range(num_expert): + w1 = get_attr(linear1_weights, i) + b1 = 
get_attr(linear1_biases, i) + w2 = get_attr(linear2_weights, i) + b2 = get_attr(linear2_biases, i) + + self.linear1_weights.append(self.create_parameter( + shape=[d_model, dim_feedforward], + attr=w1, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + )) + self.linear2_weights.append(self.create_parameter( + shape=[dim_feedforward, d_model], + attr=w2, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + )) + self.linear1_biases.append(self.create_parameter( + shape=[dim_feedforward], + attr=b1, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + )) + self.linear2_biases.append(self.create_parameter( + shape=[d_model], + attr=b2, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + )) + + def forward(self, inp): + inp = _legacy_C_ops.fused_moe( + inp, + self.gate_weight, + self.gate_bias, + self.ln_scale, + self.ln_bias, + list(self.linear1_weights), + list(self.linear1_biases), + list(self.linear2_weights), + list(self.linear2_biases), + 'pre_layer_norm', + True, + 'ln_epsilon', + 1e-5, + 'topk', + self.top_k, + 'mp_size', + self.mp_size, + 'mp_rank', + self.mp_rank, + 'num_expert', + self.num_expert, + 'world_size', + self.world_size, + 'moe_ring_id', + -1 if self.group is None else self.group.id + ) + return inp + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with paddle.no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.linear1_weights) + trans_to_fp16(self.linear1_biases) + trans_to_fp16(self.linear2_weights) + trans_to_fp16(self.linear2_biases) + _ = _to_dtype(self.gate_weight, dtype) + _ = _to_dtype(self.gate_bias, dtype) self._dtype = dtype \ No newline at end of file
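For review convenience, here is a rough, single-rank reference of what the fused_moe kernel computes (steps 1-11 in fused_moe_op.cu), written with plain Paddle ops. It is a sketch only: it assumes pre_layer_norm=True and world_size == mp_size == 1, and it skips the all_to_all dispatch and the counting/permutation bookkeeping; the function name and variables are illustrative and not part of this patch. Note that, as in the kernel, the raw top-k gate scores are used directly as combine weights (no softmax inside the op).

import paddle
import paddle.nn.functional as F

def reference_fused_moe(x, ln_scale, ln_bias, gate_w, gate_b,
                        w1_list, b1_list, w2_list, b2_list,
                        top_k=2, ln_epsilon=1e-5):
    # x: [bsz, seq_len, d_model] -> returns a tensor of the same shape.
    bsz, seq_len, d_model = x.shape
    inp = F.layer_norm(x, [d_model], weight=ln_scale, bias=ln_bias,
                       epsilon=ln_epsilon)                       # step 1: pre layer norm
    inp = inp.reshape([bsz * seq_len, d_model])                  # step 2: flatten tokens
    gate = paddle.matmul(inp, gate_w) + gate_b                   # step 3: gate linear
    prob, idx = paddle.topk(gate, k=top_k, axis=-1)              # step 3: top-k expert choice
    out = paddle.zeros_like(inp)
    for k in range(top_k):
        w_k = prob[:, k:k + 1]                                   # raw gate score, no softmax
        for e, (w1, b1, w2, b2) in enumerate(
                zip(w1_list, b1_list, w2_list, b2_list)):
            rows = paddle.nonzero(idx[:, k] == e).flatten()      # tokens routed to expert e
            if rows.shape[0] == 0:
                continue
            h = F.gelu(paddle.matmul(paddle.gather(inp, rows), w1) + b1)  # step 6: linear1 + gelu
            h = paddle.matmul(h, w2) + b2                                 # step 6: linear2
            out = paddle.scatter(out, rows, paddle.gather(w_k, rows) * h,
                                 overwrite=False)                # steps 7-8: gather + weighted combine
    return x + out.reshape([bsz, seq_len, d_model])              # steps 10-11: reshape + residual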
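The counting/permutation bookkeeping (NumberCountCompute, the cumsum, and AssignPosCompute in fused_moe_op.h) is the least obvious part of the kernel. The following tiny dry run in plain Python mirrors its logic on a made-up routing table, purely for illustration:

# 3 tokens, top_k = 2, 4 experts in total; flattened (token, k) slot -> expert id.
topk_idx = [1, 3, 0, 1, 1, 2]
num_expert = 4

local_expert_count = [0] * num_expert     # NumberCount: how often each expert was picked
for e in topk_idx:
    local_expert_count[e] += 1            # -> [1, 3, 1, 1]

lec_cum, running = [], 0                  # inclusive cumulative sum -> [1, 4, 5, 6]
for c in local_expert_count:
    running += c
    lec_cum.append(running)

pos = [0] * lec_cum[-1]                   # AssignPos: slot indices grouped by expert
for slot, e in enumerate(topk_idx):
    lec_cum[e] -= 1                       # kernel: atomicAdd(cum_count + e, -1)
    pos[lec_cum[e]] = slot
# pos == [2, 4, 3, 0, 5, 1]: expert 0 owns slot 2, expert 1 owns slots {4, 3, 0}, ...
# (the CUDA version may order slots within one expert differently).

temp_pos = [p // 2 for p in pos]          # pos // top_k: slot -> source token row,
                                          # used by index_select to build the dispatch buffer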
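Finally, a sketch of how FusedMoELayer might be wired up under expert parallelism. This is a hypothetical launch configuration, not part of the patch: the group objects come from whatever distributed setup the user already has, and the snippet assumes it is run with paddle.distributed.launch on several GPUs (num_expert is the number of experts held by each rank).

import paddle
import paddle.distributed as dist
from paddle.incubate.nn import FusedMoELayer

dist.init_parallel_env()
# Use one group spanning all ranks as the MoE (expert-parallel) group;
# pass a separate mp_group as well if tensor model parallelism is enabled.
moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = FusedMoELayer(d_model=1024, dim_feedforward=4096,
                      num_expert=8, top_k=2,
                      moe_group=moe_group, mp_group=None)

x = paddle.rand([2, 128, 1024], dtype='float32')
y = layer(x)   # [2, 128, 1024]; the residual add is done inside the op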