Commit dc3946e

unify moe implementation for llama4 and deepseek_v3
1 parent a204e31 commit dc3946e

File tree

17 files changed: +246 -481 lines changed


torchtitan/components/quantization/float8.py

Lines changed: 1 addition & 3 deletions
@@ -10,9 +10,7 @@
 
 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.experiments.llama4.infra.expert_parallel import (
-    set_token_group_alignment_size_m,
-)
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,

torchtitan/components/quantization/mx.py

Lines changed: 2 additions & 5 deletions
@@ -13,6 +13,7 @@
 
 from torchtitan.config.job_config import JobConfig, MX
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,

@@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
        if job_config.mx.moe_fqns_prototype:
-            from torchtitan.experiments.llama4.infra.expert_parallel import (
-                set_token_group_alignment_size,
-            )
-
             mxfp8_block_size = 32
-            set_token_group_alignment_size(mxfp8_block_size)
+            set_token_group_alignment_size_m(mxfp8_block_size)
             logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
 
         # Configure MXFP8
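
Aside (not part of the commit): setting the alignment to the mxfp8 block size means each expert's token group gets padded up to a multiple of 32 before the grouped GEMM. A minimal sketch of that rounding, with a hypothetical helper name:

# Illustration only; `pad_group_size` is a made-up stand-in for whatever
# consumes TOKEN_GROUP_ALIGN_SIZE_M inside expert_parallel.py.
def pad_group_size(num_tokens: int, align_m: int = 32) -> int:
    """Round a token-group length up to the next multiple of align_m."""
    return num_tokens + (-num_tokens % align_m)

assert pad_group_size(70) == 96   # 70 tokens -> padded to 96
assert pad_group_size(64) == 64   # already aligned, no padding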

torchtitan/experiments/llama4/infra/expert_parallel.py renamed to torchtitan/distributed/expert_parallel.py

Lines changed: 36 additions & 1 deletion
@@ -11,7 +11,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+
+# from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,

@@ -24,6 +25,40 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
+# TODO: there is memory leak issue with AC + PT-D all_to_all_single_autograd
+# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
+class _A2A(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, out_splits, in_splits, group):
+        if isinstance(out_splits, torch.Tensor):
+            out_splits = out_splits.tolist()
+        if isinstance(in_splits, torch.Tensor):
+            in_splits = in_splits.tolist()
+        T_out = int(sum(out_splits))
+
+        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
+        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+
+        ctx.in_splits = in_splits
+        ctx.out_splits = out_splits
+        ctx.group = group
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_y):
+        # grad wrt input has length sum(in_splits)
+        T_in = int(sum(ctx.in_splits))
+        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+        dist.all_to_all_single(
+            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
+        )
+        return grad_x, None, None, None
+
+
+def all_to_all_single_autograd(x, out_splits, in_splits, group):
+    return _A2A.apply(x, out_splits, in_splits, group)
+
+
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
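
Aside (illustrative, not part of the commit): the drop-in all_to_all_single_autograd keeps the signature of the functional-collectives version it comments out, so expert-parallel callers are unchanged. A hedged usage sketch, assuming a backend that supports all_to_all_single (NCCL, or Gloo on recent PyTorch) and a hypothetical script launched with torchrun:

# demo_a2a.py (hypothetical file name); run: torchrun --nproc_per_node=2 demo_a2a.py
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd


def main() -> None:
    dist.init_process_group(backend="gloo")
    world_size = dist.get_world_size()

    # Uniform splits keep the demo simple: every rank sends 2 rows to every rank.
    in_splits = [2] * world_size
    out_splits = [2] * world_size
    x = torch.randn(sum(in_splits), 4, requires_grad=True)

    y = all_to_all_single_autograd(x, out_splits, in_splits, dist.group.WORLD)
    y.sum().backward()  # backward runs the reverse all-to-all inside _A2A
    assert x.grad is not None and x.grad.shape == x.shape

    dist.destroy_process_group()


if __name__ == "__main__":
    main()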

File renamed without changes.

torchtitan/experiments/llama4/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -9,6 +9,7 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.models.llama3 import pipeline_llama
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
 from .infra.parallelize import parallelize_llama

@@ -40,7 +41,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
     ),
     "17bx128e": TransformerModelArgs(

@@ -51,7 +52,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
     ),
     "debugmodel_irope": TransformerModelArgs(
         dim=256,

@@ -73,7 +74,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
         every_n_layers_nope=4,
         use_flex_attn=True,

@@ -87,7 +88,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
         every_n_layers_nope=4,
         use_flex_attn=True,
         attn_mask_type="block_causal",

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -21,16 +21,16 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
-from torchtitan.tools.logging import logger
-
-from .expert_parallel import (
+from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
     NoParallel,
     TensorParallel,
 )
 
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.tools.logging import logger
+
 
 def parallelize_llama(
     model: nn.Module,

torchtitan/experiments/llama4/model/args.py

Lines changed: 8 additions & 13 deletions
@@ -5,11 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 from torch import nn
 
 from torchtitan.config import JobConfig
+
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols import BaseModelArgs
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability

@@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
-    eos_id: int = 0
     # iRoPE settings
     # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
     # used every n layers. Other layers uses RoPE (rotary positional embedding) and

@@ -45,17 +46,11 @@
     every_n_layers_nope: int | None = None
     fixed_attn_block_size: int = 8192
 
-    # MoE args
-    moe_enabled: bool = True
-    num_experts: int = 8
-    use_shared_expert: bool = True
+    # MoE
+    moe_args: MoEArgs = field(default_factory=MoEArgs)
     auto_scale_hidden_dim: bool = True
     # frequency of using MoE layer instead of feedforward layer in a transformer block
     interleave_moe_layer_step: int = 2
-    # token-choice
-    top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
-    load_balance_coeff: float | None = 1e-3
 
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len

@@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len
 
-        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+        if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
             )
-            self.use_grouped_mm = False
+            self.moe_args.use_grouped_mm = False
 
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise NotImplementedError(

@@ -112,7 +107,7 @@ def get_nparams_and_flops(
         nparams_sparse_active = (
             nparams_moe_router
             + nparams_shared_expert
-            + nparams_experts * self.top_k // self.num_experts
+            + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
 
         logger.info(
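
Aside (not part of the commit): the last hunk only repoints the active-parameter estimate at the nested moe_args; the arithmetic is unchanged. With token-choice routing, roughly top_k / num_experts of the routed-expert weights are active per token. A quick trace of just that term with made-up counts:

# Hypothetical numbers, only to illustrate the routed-experts term of the sum.
nparams_experts = 16_000_000          # all routed experts combined (made up)
top_k, num_experts = 1, 16            # 17bx16e-style routing
active = nparams_experts * top_k // num_experts
assert active == 1_000_000            # 1/16 of the routed-expert parameters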

torchtitan/experiments/llama4/model/model.py

Lines changed: 19 additions & 3 deletions
@@ -10,10 +10,10 @@
 from torch import nn
 
 from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.moe import MoE
 from torchtitan.protocols import ModelProtocol
 
 from .args import TransformerModelArgs
-from .moe import MoE
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:

@@ -296,12 +296,28 @@ def __init__(
         self.attention = Attention(model_args, attn_use_rope, fixed_attn_block_size)
 
         # use MoE layer for every interleave_moe_layer_step FFN layers
+        moe_args = model_args.moe_args
         self.moe_enabled = (
-            model_args.moe_enabled
+            moe_args.moe_enabled
             and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
         )
         if self.moe_enabled:
-            self.moe = MoE(model_args)
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            ffn_dim_multiplier = model_args.ffn_dim_multiplier
+            hidden_dim = int(2 * hidden_dim / 3)
+            if ffn_dim_multiplier is not None:
+                hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = moe_args.top_k + moe_args.num_shared_experts
+
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            self.moe = MoE(moe_args, dim=dim, hidden_dim=hidden_dim)
         else:
             self.feed_forward = FeedForward(
                 dim=model_args.dim,
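
Aside (not part of the commit): the inlined sizing presumably reproduces what the removed llama4 MoE module computed internally. It starts from the dense FFN width (4 * dim, cut to 2/3, scaled by ffn_dim_multiplier), optionally divides by top_k + num_shared_experts so the total activated expert width roughly matches a dense FFN, then rounds up to multiple_of. A worked trace with hypothetical sizes:

# Hypothetical sizes, just to trace the arithmetic in the hunk above.
dim, ffn_dim_multiplier, multiple_of = 1024, 1.2, 2048
top_k, num_shared_experts, auto_scale = 1, 1, True

hidden_dim = int(2 * (4 * dim) / 3)                       # 2730
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)     # 3276
if auto_scale:
    hidden_dim = int(hidden_dim / (top_k + num_shared_experts))  # 1638
hidden_dim += -hidden_dim % multiple_of                   # rounded up to 2048
assert hidden_dim == 2048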

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "none"  # ["none", "selective", "full"]
+mode = "selective"  # ["none", "selective", "full"]
 selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 43 additions & 21 deletions
@@ -12,6 +12,7 @@
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers
 from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.moe import MoEArgs
 
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 

@@ -36,10 +37,14 @@
         n_layers=3,
         n_dense_layers=1,
         n_heads=16,
-        n_routed_experts=8,
-        n_shared_experts=2,
-        n_activated_experts=3,
-        route_scale=1.0,
+        moe_args=MoEArgs(
+            num_experts=8,
+            num_shared_experts=2,
+            top_k=3,
+            score_func="softmax",
+            route_norm=True,
+            score_before_experts=False,
+        ),
         q_lora_rank=0,
         kv_lora_rank=512,
         qk_nope_head_dim=128,

@@ -55,10 +60,14 @@
         n_layers=3,
         n_dense_layers=1,
         n_heads=16,
-        n_routed_experts=8,
-        n_shared_experts=2,
-        n_activated_experts=3,
-        route_scale=1.0,
+        moe_args=MoEArgs(
+            num_experts=8,
+            num_shared_experts=2,
+            top_k=3,
+            score_func="softmax",
+            route_norm=True,
+            score_before_experts=False,
+        ),
         q_lora_rank=0,
         kv_lora_rank=512,
         qk_nope_head_dim=128,

@@ -76,10 +85,14 @@
         n_layers=27,
         n_dense_layers=1,
         n_heads=16,
-        n_routed_experts=64,
-        n_shared_experts=2,
-        n_activated_experts=6,
-        route_scale=1.0,
+        moe_args=MoEArgs(
+            num_experts=64,
+            num_shared_experts=2,
+            top_k=6,
+            score_func="softmax",
+            route_norm=True,
+            score_before_experts=False,
+        ),
         q_lora_rank=0,
         kv_lora_rank=512,
         qk_nope_head_dim=128,

@@ -95,12 +108,17 @@
         n_layers=60,
         n_dense_layers=1,
         n_heads=128,
-        n_routed_experts=160,
-        n_shared_experts=2,
-        n_activated_experts=6,
+        moe_args=MoEArgs(
+            num_experts=160,
+            num_shared_experts=2,
+            top_k=6,
+            score_func="softmax",
+            route_norm=True,
+            route_scale=16.0,
+            score_before_experts=False,
+        ),
         n_expert_groups=8,
         n_limited_groups=3,
-        route_scale=16.0,
         q_lora_rank=1536,
         kv_lora_rank=512,
         qk_nope_head_dim=128,

@@ -115,13 +133,17 @@
         n_layers=61,
         n_dense_layers=3,
         n_heads=128,
-        n_routed_experts=256,
-        n_shared_experts=1,
-        n_activated_experts=8,
+        moe_args=MoEArgs(
+            num_experts=256,
+            num_shared_experts=1,
+            top_k=8,
+            score_func="sigmoid",
+            route_norm=True,
+            route_scale=2.5,
+            score_before_experts=False,
+        ),
         n_expert_groups=8,
         n_limited_groups=4,
-        route_scale=2.5,
-        score_func="sigmoid",
         q_lora_rank=1536,
         kv_lora_rank=512,
         qk_nope_head_dim=128,
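
Aside (illustrative, not part of the commit): taken together, these configs exercise most of the unified MoEArgs surface, but the dataclass itself is not shown in this diff. A rough sketch inferred only from the fields used in this commit; the name, field defaults, and the placement of load_balance_coeff are assumptions, not torchtitan's actual definition:

# Sketch inferred from the MoEArgs(...) constructor calls in this commit.
from dataclasses import dataclass
from typing import Literal


@dataclass
class MoEArgsSketch:
    moe_enabled: bool = True
    num_experts: int = 8
    num_shared_experts: int = 1
    top_k: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_norm: bool = False
    route_scale: float = 1.0
    score_before_experts: bool = False
    use_grouped_mm: bool = True               # grouped mm vs. for-loop experts
    load_balance_coeff: float | None = 1e-3   # from the removed llama4 args; placement assumed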
