From 0f13c5dfb67e1cea21b9e20d120e1b87b80d8481 Mon Sep 17 00:00:00 2001 From: hlygit66666 <2570058140@qq.com> Date: Mon, 27 Sep 2021 05:14:48 +0000 Subject: [PATCH] fix limit_by_capacity op --- paddle/fluid/operators/limit_by_capacity_op.cc | 9 +++------ paddle/fluid/operators/limit_by_capacity_op.cu | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc index 42c18bf67c16b2..9b810a653cfb80 100644 --- a/paddle/fluid/operators/limit_by_capacity_op.cc +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -52,11 +52,8 @@ class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_WITHOUT_GRADIENT(LimitByCapacityOp, ops::LimitByCapacityOp, - ops::LimitByCapacityOpMaker); - -REGISTER_OPERATOR(limit_by_capacity, ops::LimitByCapacityOp, - ops::LimitByCapacityOpMaker); - REGISTER_OP_CPU_KERNEL(limit_by_capacity, ops::LimitByCapacityOpCPUKernel, ops::LimitByCapacityOpCPUKernel); + +REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity, ops::LimitByCapacityOp, + ops::LimitByCapacityOpMaker); diff --git a/paddle/fluid/operators/limit_by_capacity_op.cu b/paddle/fluid/operators/limit_by_capacity_op.cu index 754d889a401c8d..d1e4d4b6c44b69 100644 --- a/paddle/fluid/operators/limit_by_capacity_op.cu +++ b/paddle/fluid/operators/limit_by_capacity_op.cu @@ -26,8 +26,8 @@ using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; template -__global__ void LimitByCapacity(const T* expc, int* cap, T* out, - const int n_expert, const int n_worker) { +__global__ void limit_by_capacity_impl(const T* expc, int* cap, T* out, + const int n_expert, const int n_worker) { int eid = blockIdx.y; int wid = blockIdx.x * blockDim.x + threadIdx.x; if (wid < n_worker) { @@ -67,7 +67,7 @@ class LimitByCapacityOpCUDAKernel : public framework::OpKernel { framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy); int* cap_data = capacity_copy.mutable_data(place); - LimitByCapacity<<>>( + limit_by_capacity_impl<<>>( ec_data, cap_data, out_data, n_expert, n_worker); } };