
Commit 0234127: update readme and inline docstring
1 parent a6d202a

7 files changed, +29 -16 lines

README.md (+11 -8)
@@ -75,35 +75,38 @@ the MLP layer by the `FMoE` layers.
 
 ### Using FastMoE in Parallel
 
-FastMoE supports both data parallel and model parallel.
+FastMoE supports multiple ways of parallel training. See [a comprehensive
+document for parallelism](doc/parallelism) for details. The two simplest
+ways of using FastMoE in parallel are shown below.
 
 #### Data Parallel
 
 In FastMoE's data parallel mode, both the gate and the experts are replicated on each worker.
 The following figure shows the forward pass of a 3-expert MoE with 2-way data parallel.
 
 <p align="center">
-<img src="doc/fastmoe_data_parallel.png" width="600">
+<img src="doc/parallelism/fastmoe_data_parallel.png" width="600">
 </p>
 
 For data parallel, no extra coding is needed. FastMoE works seamlessly with PyTorch's `DataParallel` or `DistributedDataParallel`.
 The only drawback of data parallel is that the number of experts is constrained by each worker's memory.
 
-#### Model Parallel
+#### Expert Parallel (also called Model Parallel in some previous versions)
 
-In FastMoE's model parallel mode, the gate network is still replicated on each worker but
+In FastMoE's expert parallel mode, the gate network is still replicated on each worker but
 experts are placed separately across workers.
 Thus, by introducing additional communication cost, FastMoE enjoys a large expert pool whose size is proportional to the number of workers.
 
 The following figure shows the forward pass of a 6-expert MoE with 2-way model parallel. Note that experts 1-3 are located in worker 1 while experts 4-6 are located in worker 2.
 
 <p align="center">
-<img src="doc/fastmoe_model_parallel.png" width="600">
+<img src="doc/parallelism/fastmoe_expert_parallel.png" width="600">
 </p>
 
-FastMoE's model parallel requires sophisticated parallel strategies that neither PyTorch nor
-Megatron-LM provides. The `fmoe.DistributedGroupedDataParallel` module is
-introduced to replace PyTorch's DDP module.
+FastMoE's expert parallel requires sophisticated parallel strategies that neither
+PyTorch nor Megatron-LM provided when FastMoE was created. The
+`fmoe.DistributedGroupedDataParallel` module is introduced to replace PyTorch's
+DDP module.
 
 #### Faster Performance Features
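The README text added above describes expert parallelism in prose; a minimal sketch of what that setup might look like in code is given below. Only `FMoE` and `fmoe.DistributedGroupedDataParallel` are named in this commit; the constructor arguments (`num_expert`, `d_model`, `world_size`, `top_k`) and the process-group initialization are assumptions and may not match the exact API.

```python
# Hedged sketch of FastMoE's expert parallel mode: each worker owns its own
# slice of the expert pool, and DistributedGroupedDataParallel replaces
# PyTorch's DDP wrapper as the README instructs. Constructor arguments below
# are assumed, not taken from this commit.
import torch
import torch.distributed as dist

from fmoe import FMoE, DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# With world_size workers and 3 experts per worker, the global expert pool
# holds 3 * world_size experts, which is the "large expert pool" the README
# refers to.
moe_layer = FMoE(
    num_expert=3,                      # experts stored on this worker
    d_model=1024,                      # hidden size of the model
    world_size=dist.get_world_size(),  # experts are sharded across workers
    top_k=2,                           # each token is routed to 2 experts
).cuda()

# Replace torch.nn.parallel.DistributedDataParallel with FastMoE's wrapper.
model = DistributedGroupedDataParallel(moe_layer)
```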

doc/parallelism/README.md (+4)
@@ -0,0 +1,4 @@
+Multi-Dimensional Parallelism Supported by FastMoE
+===
+
+A Chinese version of this document has not been written yet. Until a community contribution provides one, please use Google Translate.
File renamed without changes.
File renamed without changes.

doc/parallelism/parallelism.png (68.7 KB)

doc/readme-cn.md (+10 -8)
@@ -64,7 +64,8 @@ train(model, ...)
 
 ### Using FastMoE in a Distributed Setting
 
-FastMoE supports data parallel and model parallel.
+FastMoE supports multiple ways of parallel training. See the [detailed document on parallelism](doc/parallelism).
+The two easiest ways to use FastMoE in parallel are briefly introduced below.
 
 #### Data Parallel
 
@@ -73,29 +74,30 @@ FastMoE supports data parallel and model parallel.
 The following figure shows the forward computation of an MoE model with three experts under 2-way data parallel.
 
 <p align="center">
-<img src="fastmoe_data_parallel.png" width="600">
+<img src="parallelism/fastmoe_data_parallel.png" width="600">
 </p>
 
 For data parallel, no extra code is needed. FastMoE works seamlessly with both PyTorch's `DataParallel`
 and `DistributedDataParallel` modules. The only drawback of this mode is that
 the number of experts is limited by the memory of a single compute unit (e.g. a GPU).
 
-#### Model Parallel
+#### Expert Parallel (previously also called Model Parallel)
 
-In FastMoE's model parallel mode, the gate network is still replicated on every compute unit,
+In FastMoE's expert parallel mode, the gate network is still replicated on every compute unit,
 but the experts are placed separately on the different compute units. Thus, by introducing
 additional communication operations, FastMoE allows more experts to be trained
 at the same time, with the limit on their number growing with the number of compute units.
 
-The following figure shows a model with six experts trained with 2-way model parallel.
+The following figure shows a model with six experts trained with 2-way expert parallel.
 Note that experts 1-3 are placed on the first compute unit, while experts 4-6 are placed on the second.
 
 <p align="center">
-<img src="fastmoe_model_parallel.png" width="600">
+<img src="parallelism/fastmoe_expert_parallel.png" width="600">
 </p>
 
-FastMoE's model parallel mode requires dedicated parallel strategies that neither PyTorch
-nor Megatron-LM supports. Therefore, the `fmoe.DistributedGroupedDataParallel`
+FastMoE's expert parallel mode requires dedicated parallel strategies that neither PyTorch
+nor Megatron-LM supported (at the time FastMoE was created). Therefore, the
+`fmoe.DistributedGroupedDataParallel`
 module must be used to replace PyTorch's DDP module.
 
 ### How to Train Faster
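For the data parallel mode described in both README versions above, no FastMoE-specific wrapper is needed; a hedged sketch is simply to wrap a model containing `FMoE` layers with PyTorch's `DistributedDataParallel`. The `FMoE` arguments and the device setup below are illustrative assumptions.

```python
# Hedged sketch of the data parallel mode: the gate and all experts are
# replicated on every worker, so plain DistributedDataParallel suffices.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from fmoe import FMoE

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# world_size=1 keeps a full replica of all experts on each worker, which is
# why the number of experts is bounded by a single worker's memory.
model = FMoE(num_expert=3, d_model=1024, world_size=1, top_k=2).cuda()
model = DDP(model, device_ids=[local_rank])
```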

fmoe/layers.py (+4)
@@ -97,6 +97,10 @@ class FMoE(nn.Module):
 the output. For each worker, FMoE only computes the output of a certain
 slice of the input batch, and will all-gather the outputs after
 computation.
+* `mp_group` is a deprecated alias of `slice_group`.
+* `moe_group` stands for the group of processes that performs expert
+parallelism. The default value `None` means all processes. See the
+parallelism document for more details of the groups.
 * `top_k` stands for the number of experts each token is going to.
 * `gate` is a gate class which can be found in `fmoe.gates`.
 * `expert` can be specified as a module class, it is used to generate
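To make the docstring additions above concrete, here is an illustrative sketch of constructing an `FMoE` layer with the documented process-group arguments. Only `mp_group`, `slice_group`, `moe_group`, `top_k`, `gate`, and `expert` come from the docstring; the gate class, the expert factory, and the remaining arguments are assumptions.

```python
# Illustrative sketch only: shows how the arguments documented in the
# fmoe.layers.FMoE docstring fit together. The expert factory and several
# argument names/defaults are assumed rather than taken from this commit.
import torch.nn as nn
import torch.distributed as dist

from fmoe import FMoE
from fmoe.gates import NaiveGate  # "a gate class which can be found in fmoe.gates"


def make_expert(d_model):
    # Hypothetical expert: a small two-layer MLP per expert.
    return nn.Sequential(
        nn.Linear(d_model, 4 * d_model),
        nn.GELU(),
        nn.Linear(4 * d_model, d_model),
    )


dist.init_process_group(backend="nccl")

layer = FMoE(
    num_expert=4,                     # experts held by this worker
    d_model=512,
    world_size=dist.get_world_size(),
    moe_group=None,                   # None: all processes perform expert parallelism
    slice_group=None,                 # do not slice the input batch
    # mp_group is a deprecated alias of slice_group, so it is not passed here
    top_k=2,                          # each token goes to 2 experts
    gate=NaiveGate,                   # gate class from fmoe.gates
    expert=make_expert,               # used to generate each expert module
)
```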
