Commit dc003c3

[moe] merge moe into main (#4978)
* update moe module
* support openmoe
1 parent 8993c8a commit dc003c3

67 files changed (+7634, -1673 lines)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (+382)

Large diffs are not rendered by default.
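
Since this 382-line diff is collapsed, here is a minimal sketch of how a MoE-aware hybrid parallel plugin is typically wired into the colossalai Booster. The class name MoeHybridParallelPlugin and every constructor argument shown (tp_size, pp_size, ep_size, precision) are assumptions inferred from the file name and the existing HybridParallelPlugin convention, not read from the hidden diff.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
# Assumed class name; the collapsed diff defines the actual export.
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch(config={})

# All keyword arguments here are illustrative assumptions.
plugin = MoeHybridParallelPlugin(
    tp_size=1,          # tensor-parallel degree
    pp_size=1,          # pipeline-parallel degree
    ep_size=2,          # expert-parallel degree for MoE layers
    precision="bf16",   # mixed-precision mode
)
booster = Booster(plugin=plugin)

model = nn.Linear(512, 512)  # placeholder; a real run would use an MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# booster.boost wraps the model and optimizer for the chosen parallel layout.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)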

colossalai/context/__init__.py (-2)

@@ -1,7 +1,5 @@
 from .config import Config, ConfigException
 
-# from .moe_context import MOE_CONTEXT
-
 __all__ = [
     "Config",
     "ConfigException",

colossalai/context/moe_context.py (-132)

This file was deleted.

(New file, +185; file name not captured in this extract)

@@ -0,0 +1,185 @@
+from functools import reduce
+from typing import Any, Tuple
+
+import torch
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+    PRECISION_MAP = {
+        "fp32": (0, torch.float32),
+        "fp16": (1, torch.float16),
+        "bf16": (2, torch.bfloat16),
+    }
+
+    @triton.jit
+    def _llama_act_combine_forward(
+        X_GATE1,
+        X_GATE2,
+        X_UP,
+        Y,
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        X_GATE1 += row * stride
+        X_GATE2 += row * stride
+        X_UP += row * stride
+        Y += row * stride
+
+        # do activation and combine, and store in y
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+            y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
+            # Write output
+            tl.store(Y + cols, y, mask=mask)
+
+    @triton.jit
+    def _llama_act_combine_backward(
+        X_GATE1,
+        X_GATE2,
+        X_UP,
+        X_GATE1_GRAD,
+        X_GATE2_GRAD,
+        X_UP_GRAD,
+        Y_GRAD,
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        X_GATE1 += row * stride
+        X_GATE2 += row * stride
+        X_UP += row * stride
+        X_GATE1_GRAD += row * stride
+        X_GATE2_GRAD += row * stride
+        X_UP_GRAD += row * stride
+        Y_GRAD += row * stride
+
+        # do activation and combine, and store in y
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
+
+            # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
+            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+            x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
+            x_up_grad = x_gate2_act * x_gate1
+            x_gate1_grad = x_gate2_act * x_up
+            # grad(x * sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 - sigmoid(x)]
+            #                      = sigmoid(x) * {1 + x * [1 - sigmoid(x)]}
+            x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
+
+            # Write output
+            tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
+            tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
+            tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
+
+    class LlamaActCombine(torch.autograd.Function):
+        """
+        act(x_gate) * x_up
+
+        Args:
+            x_gate (torch.Tensor): (b, l, 2d) x_gate
+            x_up (torch.Tensor): (b, l, d) x_up
+            activation (str): only support swiglu
+            precision (str): fp32, fp16, bf16
+        """
+
+        @staticmethod
+        @custom_fwd
+        def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
+            """
+            act(x_gate) * x_up
+
+            Args:
+                x_gate (torch.Tensor): (b, l, 2d) x gate
+                x_up (torch.Tensor): (b, l, d) x up
+                activation (str): only support swiglu
+            """
+            assert activation == "swiglu", "Only swiglu is supported"
+
+            # split x gate
+            assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
+            x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
+            x_gate1 = x_gate1.contiguous()
+            x_gate2 = x_gate2.contiguous()
+            if not x_up.is_contiguous():
+                x_up = x_up.contiguous()
+            # assert shape
+            assert x_gate1.shape == x_gate2.shape == x_up.shape
+
+            # add ctx for backward
+            if x_gate.requires_grad:
+                ctx.save_for_backward(x_gate1, x_gate2, x_up)
+
+            # allocate output
+            y = torch.empty_like(x_up)
+            M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
+
+            # Less than 64KB per feature: enqueue fused kernel
+            MAX_FUSED_SIZE = 65536 // x_gate.element_size()
+            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+            if N > BLOCK_SIZE:
+                raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+            # heuristics for number of warps
+            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+            # restore setting
+            ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
+            # enqueue kernel
+            _llama_act_combine_forward[(M,)](x_gate1,
+                                             x_gate2,
+                                             x_up,
+                                             y,
+                                             x_up.stride(-2),
+                                             N,
+                                             BLOCK_SIZE=BLOCK_SIZE,
+                                             num_warps=num_warps)
+            return y
+
+        @staticmethod
+        @custom_bwd
+        def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
+            # restore from ctx
+            (x_gate1, x_gate2, x_up) = ctx.saved_tensors
+            M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
+
+            # init grad
+            y_grad = grad_outputs[0]
+            x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
+                x_gate2), torch.empty_like(x_up)
+
+            # enqueue kernel
+            _llama_act_combine_backward[(M,)](x_gate1,
+                                              x_gate2,
+                                              x_up,
+                                              x_gate1_grad,
+                                              x_gate2_grad,
+                                              x_up_grad,
+                                              y_grad,
+                                              x_up.stride(-2),
+                                              N,
+                                              BLOCK_SIZE=BLOCK_SIZE,
+                                              num_warps=num_warps)
+            x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
+            return x_gate_grad, x_up_grad, None, None
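
For reference, the fused kernel above computes y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up, i.e. x_gate1 * SiLU(x_gate2) * x_up, with x_gate split in half along its last dimension. The snippet below is a hedged usage sketch, not part of the commit: it assumes a CUDA device with triton installed and that LlamaActCombine is importable from the new module (whose path is not shown in this extract), and it checks the fused output against a plain PyTorch reference.

import torch
import torch.nn.functional as F

# from <new kernel module> import LlamaActCombine  # exact import path not shown above

b, l, d = 2, 16, 128
x_gate = torch.randn(b, l, 2 * d, device="cuda", dtype=torch.float16, requires_grad=True)
x_up = torch.randn(b, l, d, device="cuda", dtype=torch.float16, requires_grad=True)

# Fused Triton path: splits x_gate internally and applies the SwiGLU-style combine.
y_fused = LlamaActCombine.apply(x_gate, x_up)

# Plain PyTorch reference mirroring the kernel math.
x_gate1, x_gate2 = torch.split(x_gate, d, dim=-1)
y_ref = x_gate1 * F.silu(x_gate2) * x_up

torch.testing.assert_close(y_fused, y_ref, rtol=1e-2, atol=1e-2)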

(gradient handler package __init__.py; full path not captured in this extract, -2)

@@ -1,6 +1,5 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
-from ._moe_gradient_handler import MoeGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
 from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
@@ -10,6 +9,5 @@
     "DataParallelGradientHandler",
     "ZeROGradientHandler",
     "PipelineSharedModuleGradientHandler",
-    "MoeGradientHandler",
     "SequenceParallelGradientHandler",
 ]

colossalai/legacy/initialize.py (-12)

@@ -16,7 +16,6 @@
 from torch.utils.data import DataLoader
 
 from colossalai.context import Config, ConfigException
-from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
 from colossalai.legacy.amp.naive_amp import NaiveAMPModel
@@ -36,7 +35,6 @@
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
 
 
 def get_default_parser():
@@ -323,8 +321,6 @@ def initialize(
     if not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
-        elif MOE_CONTEXT.is_initialized:
-            sync_moe_model_param(model)
         elif is_using_ddp():
             sync_model_param(model, ParallelMode.DATA)
         else:
@@ -377,14 +373,6 @@ def initialize(
                     "added even though not specified in the configuration",
                     ranks=[0],
                 )
-        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
-            gradient_handler_cfg = [dict(type="MoeGradientHandler")]
-            if verbose:
-                logger.info(
-                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
-                    "added even though not specified in the configuration",
-                    ranks=[0],
-                )
         elif is_using_sequence():
             model = DDP(
                 model,

colossalai/moe/__init__.py (+17)

@@ -0,0 +1,17 @@
+from .checkpoint import MoeCheckpintIO
+from .experts import MLPExperts
+from .layers import SparseMLP
+from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator
+
+__all__ = [
+    "MLPExperts",
+    "MoeRouter",
+    "Top1Router",
+    "Top2Router",
+    "TopKRouter",
+    "NormalNoiseGenerator",
+    "UniformNoiseGenerator",
+    "SparseMLP",
+    "MoeCheckpintIO",
+]
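
The names exported here are the commit's new public MoE surface. Below is a hedged sketch of how they might be used: only the imported names come from the diff above, every constructor argument (num_experts, hidden_size, intermediate_size, router_top_k) is an illustrative assumption rather than the documented signature, and a real run would normally sit inside a colossalai-launched distributed process.

import torch

from colossalai.moe import SparseMLP

# Argument names below are assumptions; check the SparseMLP definition in
# colossalai/moe/layers.py for the real signature.
moe_mlp = SparseMLP(
    num_experts=8,           # experts per MoE layer
    hidden_size=512,         # model hidden dimension
    intermediate_size=2048,  # per-expert FFN dimension
    router_top_k=2,          # each token routed to its top-2 experts
)

x = torch.randn(4, 16, 512)
out = moe_mlp(x)  # expected shape: (4, 16, 512); routing behavior assumed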
