[MoE] Moe apis (#41092)
* add random routing op

add _random_routing api in utils

add random routing ut

* # This is a combination of 10 commits.
# The first commit's message is:
add expert count op

add ut for expert_count

# This is the 2nd commit message:

update UT only for cuda

# This is the 3rd commit message:

fix for rocm

# This is the 4th commit message:

update ut

# This is the 5th commit message:

add moe module

# This is the 6th commit message:

add expert count op

add ut for expert_count

# This is the 7th commit message:

update UT only for cuda

# This is the 8th commit message:

update ut

# This is the 9th commit message:

add moe module

# This is the 10th commit message:

make expert count private

* add assign pos op

* fix upper num name

* add api _assign_pos

* add ut for assign pos op

* update date

* add op about moe gate

update utils

add limit by capacity op

add ut for limit_by_capacity

add ut for prune_gate_by_capacity

add ut for limit_by_capacity

add ut for prune_gate_by_capacity

* fix for win

* fix bugs in test_limit_by_capacity_op

* update ut

* update for test (timeout)

* fix ut

* update

* update(fix) ut for win

* moe apis in incubate

* add assign pos op

* fix upper num name

* add api _assign_pos

* add ut for assign pos op

* update date

* fix for win

* update for test (timeout)

* fix ut

* update

* fix ut for number count

* add apis and utils

* add gate apis

* add moe and grad clip apis

* update moe apis

* add ops for moe gate

* fix

* update for base moe layer api

* add random routing op

add _random_routing api in utils

add random routing ut

* fix for dygraph

* update with random routing

* update

* fix ut for limit by capacity

* update

* update limit_by_capacity so it can easily switch to single-thread mode

* update api docs

Co-authored-by: hlygit66666 <2570058140@qq.com>
sljlp and liyagit21 authored Mar 30, 2022
1 parent 8f7c02f commit aac7879
Showing 7 changed files with 20 additions and 25 deletions.
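
Before the per-file diffs, a brief note on the gate bookkeeping ops named in the commit message (expert count, assign pos, limit by capacity, random routing). The sketch below is a rough NumPy illustration of what the first two compute, inferred from the op names and from common MoE gate implementations; the function names, shapes, and semantics are assumptions, not the actual Paddle signatures. Random routing is omitted because its exact semantics are not visible in this diff.

```python
# Illustrative NumPy sketch only, NOT the Paddle ops. Names, shapes, and
# semantics are assumptions inferred from the op names in the commit message.
import numpy as np

def expert_count(gate_idx, n_expert):
    # How many tokens the gate routed to each expert (per worker).
    return np.bincount(gate_idx, minlength=n_expert)

def assign_pos(gate_idx):
    # Positions that reorder tokens so that tokens bound for the same expert
    # become contiguous; this is the index later used by the scatter step.
    return np.argsort(gate_idx, kind="stable")

gate_idx = np.array([2, 0, 2, 1, 0])        # expert id chosen for each of 5 tokens
print(expert_count(gate_idx, n_expert=3))   # [2 1 2]
print(assign_pos(gate_idx))                 # [1 4 3 0 2]
```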
13 changes: 5 additions & 8 deletions paddle/fluid/operators/limit_by_capacity_op.cu
@@ -20,19 +20,17 @@
 namespace paddle {
 namespace operators {
 
-#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)
-
 using LoDTensor = framework::LoDTensor;
 using Tensor = framework::Tensor;
 
 template <typename T>
 __global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out,
                                        const int n_expert, const int n_worker) {
-  int eid = blockIdx.y;
-  int wid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (wid < n_worker) {
+  int eid, wid;
+  CUDA_KERNEL_LOOP(i, (n_expert * n_worker)) {
+    wid = i / n_expert;
+    eid = i % n_expert;
     auto proposal = expc[wid * n_expert + eid];
-    // int cap_left = atomicSub(cap + eid, proposal);
     auto cap_left = paddle::platform::CudaAtomicAdd(cap + eid, proposal * (-1));
     if (cap_left >= proposal) {
       out[wid * n_expert + eid] = proposal;
@@ -54,12 +52,11 @@ class LimitByCapacityOpCUDAKernel : public framework::OpKernel<T> {
     auto out = context.Output<Tensor>("Out");
 
     auto n_expert = expert_count->numel() / n_worker;
-    // std::cout << "n_expert" << n_expert << std::endl;
     const auto place = context.GetPlace();
     const auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
 
-    dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
+    dim3 grid_dim(256);
     dim3 block_dim(1024);
     auto out_data = out->mutable_data<T>(place);
     const T* ec_data = expert_count->data<T>();
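
The kernel rewrite above drops the 2-D launch grid (one block row per expert, threads spanning workers) in favor of a fixed 256 x 1024 launch plus a CUDA_KERNEL_LOOP grid-stride loop over the flattened n_expert * n_worker index space; per the commit message, this also makes it easy to switch the op to a single-thread mode. A serial Python sketch of the logic the kernel implements is given below; it is an illustration, not the Paddle implementation, and since the diff truncates the else branch, clamping the output to the remaining non-negative capacity is an assumption.

```python
# Serial reference sketch of limit_by_capacity_impl (illustration only).
# expert_count: [n_worker, n_expert] tokens each worker proposes per expert.
# capacity:     [n_expert] remaining global slots per expert.
import numpy as np

def limit_by_capacity(expert_count, capacity):
    capacity = capacity.astype(np.int64).copy()
    n_worker, n_expert = expert_count.shape
    out = np.zeros_like(expert_count)
    for i in range(n_expert * n_worker):      # mirrors CUDA_KERNEL_LOOP's flat index
        wid, eid = divmod(i, n_expert)        # wid = i / n_expert, eid = i % n_expert
        proposal = expert_count[wid, eid]
        cap_left = capacity[eid]              # value before the atomic subtract
        capacity[eid] -= proposal
        if cap_left >= proposal:
            out[wid, eid] = proposal
        else:
            out[wid, eid] = max(cap_left, 0)  # assumption: grant whatever capacity is left
    return out

ec = np.array([[3, 1],
               [2, 4]])                       # 2 workers x 2 experts
cap = np.array([4, 3])
print(limit_by_capacity(ec, cap))             # [[3 1] [1 2]]
```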
The remaining five changed files are license-header updates only; each bumps the copyright year with the same one-line change:

@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
22 changes: 10 additions & 12 deletions python/paddle/incubate/distributed/models/moe/moe_layer.py
@@ -30,8 +30,6 @@
 from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
 from paddle import fluid
 
-__all__ = ["MoeLayer"]
-
 
 def _local_scatter(inp, pos):
     if pos.shape != [0]:
@@ -71,7 +69,7 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
             'ring_id', ring_id, 'nranks', nranks)
 
 
-class MOEScatter(PyLayer):
+class MoEScatter(PyLayer):
     r"""
     Scatter input samples from [batch x sequences] to contiguous alone experts.
     If `world_size` is greater than 1, the samples will first be locally
@@ -117,10 +115,10 @@ def backward(ctx, grad):
         return grad_in, None, None, None
 
 
-class MOEGather(PyLayer):
+class MoEGather(PyLayer):
     r"""
     Gather output samples from contiguous alone experts back to [batch x
-    sequences]. Works symmetrically with MOEScatter.
+    sequences]. Works symmetrically with MoEScatter.
     """
 
     @staticmethod
@@ -225,8 +223,8 @@ def prepare_forward(gate, num_expert, world_size, moe_group):
         fwd_batch_size, )
 
 
-class MoeLayer(nn.Layer):
-    """Moe Layer
+class MoELayer(nn.Layer):
+    """MoE Layer
     Args:
         d_model: (int) model dimention
         experts: (nn.LayerList) expert networks list
@@ -243,7 +241,7 @@ class MoeLayer(nn.Layer):
     Examples:
         .. code-block:: python
         from paddle.nn import layer, LayerList
-        from paddle.distributed.moe import Moelayer
+        from paddle.distributed.moe import MoElayer
         from paddle.distributed.collective import Group
         from paddle.distributed import fleet
@@ -279,7 +277,7 @@ def forward(self, x):
            exp_layer = ExpertLayer(d_model, dim_feedforward // top_k, windex=expi, num_expert=num_experts)
            experts_list.append(exp_layer)
-        moeLayer = MoeLayer(d_model = d_model,
+        moeLayer = MoELayer(d_model = d_model,
                             experts=experts_list,
                             gate=gate_config,
                             moe_group=moe_group,
@@ -295,7 +293,7 @@ def __init__(self,
                  moe_group=None,
                  mp_group=None,
                  **kwargs):
-        super(MoeLayer, self).__init__()
+        super(MoELayer, self).__init__()
 
         recompute_interval = kwargs.get("recompute_interval", 0)
 
@@ -385,7 +383,7 @@ def forward(self, inp):
         temp_pos = pos
         assert topk == self.top_k
 
-        x = MOEScatter.apply(inp, temp_pos, local_expert_count,
+        x = MoEScatter.apply(inp, temp_pos, local_expert_count,
                              global_expert_count, fwd_batch_size,
                              self.world_size, self.group)
 
@@ -416,7 +414,7 @@ def experts_fwd(x, fwd_expert_count, experts):
         if len(gate.shape) == 2:
             out_batch_size *= gate.shape[1]
 
-        x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
+        x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
                             out_batch_size, self.world_size, self.group)
 
         x = x.reshape([-1, self.top_k, d_model])
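
The renames above (MOEScatter to MoEScatter, MOEGather to MoEGather, MoeLayer to MoELayer) are cosmetic, but the scatter / expert / gather pipeline they implement is the core of the layer's forward pass. Below is a minimal single-process NumPy sketch of that data flow; the distributed all-to-all exchange, top-k gating, capacity limiting, and recompute logic are omitted, and every name in it is illustrative rather than part of the Paddle API.

```python
# Single-process sketch of the scatter -> expert -> gather flow (illustration only;
# the real MoEScatter / MoEGather also exchange tokens across workers).
import numpy as np

def moe_forward_local(inp, gate_idx, experts):
    # inp: [n_tokens, d_model]; gate_idx: [n_tokens] expert chosen per token (top-1 gate)
    pos = np.argsort(gate_idx, kind="stable")     # group tokens by expert ("assign pos")
    scattered = inp[pos]                          # MoEScatter: contiguous per-expert slices
    counts = np.bincount(gate_idx, minlength=len(experts))
    out = np.empty_like(scattered)
    start = 0
    for eid, cnt in enumerate(counts):            # run each expert on its own slice
        out[start:start + cnt] = experts[eid](scattered[start:start + cnt])
        start += cnt
    gathered = np.empty_like(out)                 # MoEGather: restore original token order
    gathered[pos] = out
    return gathered

experts = [lambda x: 2.0 * x, lambda x: x + 1.0]  # two toy "experts"
x = np.random.randn(5, 4).astype("float32")
gate = np.array([1, 0, 1, 0, 0])
y = moe_forward_local(x, gate, experts)
assert np.allclose(y[1], x[1] * 2.0)              # token 1 was routed to expert 0
```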
