diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc new file mode 100644 index 00000000000000..2298f193707a25 --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/limit_by_capacity_op.h" + +namespace paddle { +namespace operators { + +class LimitByCapacityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("expert_count"), "Input", "expert_count", + "LimitByCapacity"); + OP_INOUT_CHECK(ctx->HasInput("capacity"), "Input", "capacity", + "LimitByCapacity"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LimitByCapacity"); + + ctx->ShareDim("expert_count", "Out"); + ctx->ShareLoD("expert_count", "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // the dtype of the expert_count and capacity should be same as int64 + auto expert_count_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "expert_count"); + auto capacity_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "capacity"); + + PADDLE_ENFORCE_EQ( + expert_count_dtype, capacity_dtype, + platform::errors::InvalidArgument( + "The dtype of the expert_count and capacity should be same")); + + PADDLE_ENFORCE_EQ( + expert_count_dtype, framework::proto::VarType::INT64, + platform::errors::InvalidArgument("The dtype of the expert_count and " + "capacity should be same as int64")); + return framework::OpKernelType(expert_count_dtype, ctx.GetPlace()); + } +}; + +class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("expert_count", "(Tensor) The input expert count tensor."); + AddInput("capacity", "(Tensor) The input capacity."); + AddOutput("Out", + "(Tensor) The output tensor expert count limit by capacity."); + AddAttr("n_worker", "(int), The number of works."); + AddComment( + R"DOC(limit_by_capacity Operator.limit expert count by capacity.)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CPU_KERNEL(limit_by_capacity, ops::LimitByCapacityOpCPUKernel, + ops::LimitByCapacityOpCPUKernel); + +REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity, ops::LimitByCapacityOp, + ops::LimitByCapacityOpMaker); diff --git a/paddle/fluid/operators/limit_by_capacity_op.cu b/paddle/fluid/operators/limit_by_capacity_op.cu new file mode 100644 index 00000000000000..a188c812ac278b --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.cu @@ -0,0 +1,83 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/limit_by_capacity_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +__global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out, + const int n_expert, const int n_worker) { + int eid = blockIdx.y; + int wid = blockIdx.x * blockDim.x + threadIdx.x; + if (wid < n_worker) { + auto proposal = expc[wid * n_expert + eid]; + // int cap_left = atomicSub(cap + eid, proposal); + auto cap_left = paddle::platform::CudaAtomicAdd(cap + eid, proposal * (-1)); + if (cap_left >= proposal) { + out[wid * n_expert + eid] = proposal; + } else if (cap_left >= 0) { + out[wid * n_expert + eid] = cap_left; + } else { + out[wid * n_expert + eid] = 0; + } + } +} + +template +class LimitByCapacityOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto expert_count = context.Input("expert_count"); + auto capacity = context.Input("capacity"); + auto n_worker = context.Attr("n_worker"); + auto out = context.Output("Out"); + + auto n_expert = expert_count->numel() / n_worker; + // std::cout << "n_expert" << n_expert << std::endl; + const auto place = context.GetPlace(); + const auto& dev_ctx = + context.template device_context(); + + dim3 grid_dim(CEIL(n_worker, 1024), n_expert); + dim3 block_dim(1024); + auto out_data = out->mutable_data(place); + const T* ec_data = expert_count->data(); + + framework::Tensor capacity_copy; + framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy); + T* cap_data = capacity_copy.mutable_data(place); + + limit_by_capacity_impl<<>>( + ec_data, cap_data, out_data, n_expert, n_worker); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(limit_by_capacity, + ops::LimitByCapacityOpCUDAKernel); diff --git a/paddle/fluid/operators/limit_by_capacity_op.h b/paddle/fluid/operators/limit_by_capacity_op.h new file mode 100644 index 00000000000000..b13a64faed8e11 --- /dev/null +++ b/paddle/fluid/operators/limit_by_capacity_op.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace paddle { +namespace operators { + +template +class LimitByCapacityOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support limit by capacity op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cc b/paddle/fluid/operators/prune_gate_by_capacity_op.cc new file mode 100644 index 00000000000000..091b33884bfbaa --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/prune_gate_by_capacity_op.h" + +namespace paddle { +namespace operators { + +class PruneGateByCapacityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("GateIdx"), "Input", "GateIdx", + "prun_gate_by_capacity"); + OP_INOUT_CHECK(ctx->HasInput("ExpertCount"), "Input", "ExpertCount", + "prun_gate_by_capacity"); + + OP_INOUT_CHECK(ctx->HasOutput("NewGateIdx"), "Output", "NewGateIdx", + "prun_gate_by_capacity"); + // OP_INOUT_CHECK(ctx->HasOutput("ExpertCountOut"), "Output", + // "ExpertCountOut", + // "prun_gate_by_capacity"); + // auto gate_idx_dims = ctx->GetInputDim("GateIdx"); + auto expert_count_dims = ctx->GetInputDim("ExpertCount"); + + int64_t n_expert = ctx->Attrs().Get("n_expert"); + int64_t n_worker = ctx->Attrs().Get("n_worker"); + + int64_t expert_count_num_ele = 1; + for (int64_t i = 0; i < expert_count_dims.size(); i++) { + expert_count_num_ele *= expert_count_dims[i]; + } + + PADDLE_ENFORCE_EQ( + expert_count_num_ele, n_expert * n_worker, + platform::errors::Unavailable( + "The number of elements for expert_count is ( %ld ) incorrect. " + "Because the number of expert_count must equal the " + "product of n_worker ( %ld ) and n_expert ( %ld ). " + "Please input appropriate expert_count again!", + expert_count_num_ele, n_worker, n_expert)); + + auto gate_idx_in_dims = ctx->GetInputDim("GateIdx"); + // auto expert_count_in_dims = ctx->GetInputDim("ExpertCount"); + ctx->SetOutputDim("NewGateIdx", gate_idx_in_dims); + // ctx->SetOutputDim("ExpertCountOut", expert_count_in_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto gate_idx_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx"); + auto expert_count_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "ExpertCount"); + PADDLE_ENFORCE_EQ( + gate_idx_data_type, expert_count_data_type, + platform::errors::InvalidArgument( + "The dtype of the gate_idx and expert_count should be same")); + PADDLE_ENFORCE_EQ(gate_idx_data_type, framework::proto::VarType::INT64, + platform::errors::InvalidArgument( + "The dtype of the gate_idx and expert_count should " + "be same as int64")); + return framework::OpKernelType(gate_idx_data_type, ctx.device_context()); + } +}; + +class PruneGateByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("GateIdx", + "(Tensor), The gate_id sequence corresponding to the input data."); + AddInput("ExpertCount", + "(Tensor), The quantity value counted on the gate_id sequence of " + "the input data."); + AddAttr("n_expert", "The number of Experts on each worker") + .SetDefault(0); + AddAttr("n_worker", "The number of workers on the trainer") + .SetDefault(0); + + AddOutput("NewGateIdx", + "(Tensor), The gate_id sequence corresponding to the new input " + "data after passing through prune."); + // AddOutput( + // "ExpertCountOut", + // "(Tensor), The copy quantity value counted on the gate_id sequence of + // " + // "the input data."); + + AddComment(R"DOC( +prune_gate_by_capacity Operator. + +This operator is used to prune gate by capacity(CUDA). + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(prune_gate_by_capacity, ops::PruneGateByCapacityOp, + ops::PruneGateByCapacityOpMaker); + +REGISTER_OP_CPU_KERNEL( + prune_gate_by_capacity, + ops::PruneGateByCapacityCPUKernel, + ops::PruneGateByCapacityCPUKernel); diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cu b/paddle/fluid/operators/prune_gate_by_capacity_op.cu new file mode 100644 index 00000000000000..ab7d2f574213d7 --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cu @@ -0,0 +1,123 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/prune_gate_by_capacity_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { +using LoDTensor = framework::LoDTensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void prune_gate_by_capacity_kernel(const T1* gate_idx_data, + T1* new_gate_idx_data, + T2* expert_count_data, + const int64_t batch_size) { + CUDA_KERNEL_LOOP(i, batch_size) { + auto orig_cap = + platform::CudaAtomicAdd(expert_count_data + gate_idx_data[i], -1); + if (orig_cap <= 0) { + new_gate_idx_data[i] = -1; + } else { + new_gate_idx_data[i] = gate_idx_data[i]; + } + } +} + +template +class PruneGateByCapacityFunctor { + public: + PruneGateByCapacityFunctor(const framework::ExecutionContext& context, + const framework::LoDTensor* gate_idx, + framework::LoDTensor* expert_count_out, + T1* new_gate_idx_data) + : context_(context), + gate_idx_(gate_idx), + expert_count_out_(expert_count_out), + new_gate_idx_data_(new_gate_idx_data) {} + + template + void apply() { + auto batch_size = gate_idx_->numel(); + auto* gate_idx_data = gate_idx_->data(); + + auto& dev_ctx = context_.template device_context(); + auto* expert_count_out_data = expert_count_out_->data(); + // framework::Tensor cpu_expert_count; + // framework::TensorCopySync(*expert_count_out_, platform::CPUPlace(), + // &cpu_expert_count); + int blocks = NumBlocks(batch_size); + int threads = kNumCUDAThreads; + + prune_gate_by_capacity_kernel<<>>( + gate_idx_data, new_gate_idx_data_, expert_count_out_data, batch_size); + } + + private: + const framework::ExecutionContext context_; + const framework::LoDTensor* gate_idx_; + framework::LoDTensor* expert_count_out_; + T1* new_gate_idx_data_; +}; + +template +static void VisitDataType(framework::proto::VarType::Type type, + Visitor visitor) { + if (type == framework::proto::VarType::INT64) { + visitor.template apply(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The recieved values gate_id type %s can not meet input requirements. " + "Because the given gate_id data type of operators must be " + "int64. Please input appropriate gate_id again! ", + framework::DataTypeToString(type))); + } +} + +template +class PruneGateByCapacityCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* gate_idx = context.Input("GateIdx"); + auto* expert_count = context.Input("ExpertCount"); + // auto* expert_count_out = context.Output("ExpertCountOut"); + auto* new_gate_idx = context.Output("NewGateIdx"); + auto* new_gate_idx_data = new_gate_idx->mutable_data(context.GetPlace()); + + framework::LoDTensor expert_count_out; + framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out); + PruneGateByCapacityFunctor functor( + context, gate_idx, &expert_count_out, new_gate_idx_data); + VisitDataType(expert_count->type(), functor); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL( + prune_gate_by_capacity, + ops::PruneGateByCapacityCUDAKernel); diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.h b/paddle/fluid/operators/prune_gate_by_capacity_op.h new file mode 100644 index 00000000000000..d7a00bd40d786f --- /dev/null +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class PruneGateByCapacityCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "prune_gate_by_capacity is not supported on CPU.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 53f4a93f6480e8..863706ad433180 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -809,3 +809,91 @@ def watch_local_trainers(procs, nranks): raise return alive + + +def limit_by_capacity(expert_count, capacity, n_worker): + """ + limit the expert count by capacity. + Args: + expert_count (Tensor): Tensor. The input expert count whose data type should be int32 or int64. + capacity (Tensor): Tensor. The input capacity whose data type should be int32 or int64 and the elements of capacity should be the same with expert_count.numel()/n_work. + n_work (int): The number of the works. + Returns: + out (Tensor): The output expert count limit by capacity. + Examples: + .. code-block:: python + # required: distributed + import paddle + + gate_idx = [ + [0, 2], + [0, 2] + ] + n_expert = 6 + gate_idx = paddle.to_tensor(gate_idx, dtype="int32") + expert_count = paddle.distributed.utils.expert_count(gate_idx, n_expert) + print(expert_count) # the result: [2, 0, 2, 0, 0, 0] + """ + if in_dygraph_mode(): + return core.ops.expert_count(gate_idx, 'n_expert', n_expert) + else: + op_type = 'expert_count' + + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype) + + helper.append_op( + type=op_type, + inputs={'gate_idx': gate_idx}, + outputs={'Out': out}, + attrs={'n_expert': n_expert}) + return out + + +def prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker): + """ + prune gate by capacity(only support CUDA) + + Args: + gate_idx (Tensor): Represents the gate_id sequence corresponding to the input data with type int32, int64. + expert_count (Tensor): The quantity value counted on the gate_id sequence of the input data with type int32, int64. + n_expert(int,optional): The number of Experts on each worker with type int64. + n_worker(int,optional): The number of workers on the trainer with type int64. + + Returns: + new_gate_idx (Tensor): The gate_id sequence corresponding to the new input data after passing through prune. + + Examples: + .. code-block:: python + + import paddle + gate_idx = paddle.to_tensor([1, 3, 3, 3, 3, 2, 1, 1], dtype='int32') + expert_count = paddle.to_tensor([0, 3, 1, 3, 0, 0, 0, 0], dtype='int32') + n_expert = 8 + n_worker = 1 + new_gate_id = paddle.distributed.utils.prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker) + print(new_gate_id) + # Tensor(shape=[8], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [1, 3, 3, 3, -1, 2, 1, 1]) + """ + + if in_dygraph_mode(): + return core.ops.prune_gate_by_capacity( + gate_idx, expert_count, "n_expert", n_expert, "n_worker", n_worker) + check_variable_and_dtype(gate_idx, 'GateIdx', ['int32', 'int64'], + 'paddle.distributed.utils.prune_gate_by_capacity') + check_variable_and_dtype(expert_count, 'ExpertCount', ['int32', 'int64'], + 'paddle.distributed.utils.prune_gate_by_capacity') + + helper = LayerHelper('prune_gate_by_capacity', **locals()) + new_gate_idx = helper.create_variable_for_type_inference( + dtype=gate_idx.dtype) + helper.append_op( + type='prune_gate_by_capacity', + inputs={'GateIdx': gate_idx, + "ExpertCount": expert_count}, + outputs={'NewGateIdx': new_gate_idx}, + attrs={"n_expert": n_expert, + "n_worker": n_worker}) + + return new_gate_idx diff --git a/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py b/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py new file mode 100644 index 00000000000000..4a0ada3b2fa1a9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_prune_gate_by_capacity_op.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np + + +class TestPruneGateByCapacityOp(unittest.TestCase): + def init_test_case(self): + self.n_expert = 5 + self.n_worker = 1 + self.gate_idx = np.array([1, 3, 3, 3, 3, 2, 1, 1]).astype("int64") + self.expert_count = np.array([0, 3, 1, 3, 0]).astype("int64") + + def setUp(self): + self.init_test_case() + self.place = paddle.CUDAPlace(0) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + gate_idx_tensor = paddle.static.data( + 'GateIdx', shape=self.gate_idx.shape, dtype="int64") + expert_count_tensor = paddle.static.data( + 'ExpertCount', shape=self.expert_count.shape, dtype="int64") + out = paddle.distributed.utils.prune_gate_by_capacity( + gate_idx_tensor, expert_count_tensor, self.n_expert, + self.n_worker) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={ + 'GateIdx': self.gate_idx, + 'ExpertCount': self.expert_count, + }, + fetch_list=out) + + print("---------------------------------") + print("static_api:") + print("gate_idx:", self.gate_idx) + print("expert_count:", self.expert_count) + print("new_gate_idx:", res) + print("----------------------------------") + + run(self.place) + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + gate_idx_tensor = paddle.to_tensor(self.gate_idx) + expert_count_tensor = paddle.to_tensor(self.expert_count) + out = paddle.distributed.utils.prune_gate_by_capacity( + gate_idx_tensor, expert_count_tensor, self.n_expert, + self.n_worker) + + print("---------------------------------") + print("dygraph_api:") + print("gate_idx:", self.gate_idx) + print("expert_count:", self.expert_count) + print("new_gate_idx:", out) + print("----------------------------------") + + run(self.place) + + +if __name__ == '__main__': + unittest.main()