Commit

Merge pull request #27 from sljlp/moe
fix moe bugs
lilong12 authored Dec 23, 2021
2 parents 4130e4e + 092289c commit 187300e
Showing 1 changed file with 6 additions and 3 deletions.
python/paddle/distributed/model/moe/moe_layer.py: 6 additions & 3 deletions
@@ -228,11 +228,11 @@ def prepare_forward(gate, num_expert, world_size, moe_group):
 class MoeLayer(nn.Layer):
     """Moe Layer
     Args:
+        d_model: (int) model dimention
+        experts: (nn.LayerList) expert networks list
         gate_config: (dict): gate network config, containing 2 keys:
             `type`(str) value can be: "naive", "gshard", "switch" or None, default is "gshard"
             `top_k`(int) default value is 2
-        d_model: (int) model dimention
-        experts: (nn.LayerList) expert networks list
         moe_group: moe group for experts communication
         mp_group: mp group for mp commutication
         kwargs: other parameters
@@ -301,7 +301,10 @@ def __init__(self,
         assert isinstance(gate_config, dict), "gate config' type must be dict"
         # only support mp/dp
         self.group = moe_group
-        self.world_size = self.group.nranks
+
+        self.world_size = 1
+        if self.group is not None:
+            self.world_size = self.group.nranks
         self.num_expert = len(experts)
         self.recompute_interval = recompute_interval
         assert experts is not None
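
For readers skimming the diff: the second hunk lets MoeLayer tolerate a missing communication group by defaulting world_size to 1 when moe_group is None, instead of unconditionally reading self.group.nranks. Below is a minimal sketch of that guard in isolation; the resolve_world_size helper is hypothetical, introduced only for illustration, and is not part of the Paddle API.

    # Hypothetical helper mirroring the guard added in this commit.
    # Before the fix, reading moe_group.nranks raised AttributeError
    # whenever no distributed group was supplied.
    def resolve_world_size(moe_group):
        world_size = 1
        if moe_group is not None:
            world_size = moe_group.nranks
        return world_size

    assert resolve_world_size(None) == 1  # single-process run, no moe_group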
