
Commit 85dc2ad

unify moe implementation for llama4 and deepseek_v3

1 parent a204e31
File tree

20 files changed: +307 -553 lines


torchtitan/components/optimizer.py

Lines changed: 53 additions & 0 deletions
@@ -24,6 +24,7 @@
 __all__ = [
     "OptimizersContainer",
     "build_optimizers",
+    "build_optimizers_with_moe_load_balancing",
 ]
 
 
@@ -323,3 +324,55 @@ def build_optimizers(
     )
 
     return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
+
+
+def build_optimizers_with_moe_load_balancing(
+    model_parts: list[nn.Module],
+    optimizer_config: OptimizerConfig,
+    parallel_dims: ParallelDims,
+    ft_manager: FTManager | None = None,
+) -> OptimizersContainer:
+    optimizers = build_optimizers(
+        model_parts=model_parts,
+        optimizer_config=optimizer_config,
+        parallel_dims=parallel_dims,
+        ft_manager=ft_manager,
+    )
+
+    # for MoE auxiliary-loss-free load balancing
+    def _update_expert_bias(
+        model_parts: list[nn.Module],
+        parallel_dims: ParallelDims,
+    ):
+        dp_cp_mesh = (
+            parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
+        )
+        # TODO: Currently this sync is blocking (thus exposed) and happens on the
+        # default compute stream. Need to assess if this is OK performance-wise.
+        for model_part in model_parts:
+            for transformer_block in model_part.layers.values():
+                if transformer_block.moe_enabled:
+                    moe = transformer_block.moe
+                    if moe.load_balance_coeff is None:
+                        return
+
+                    if dp_cp_mesh is not None:
+                        torch.distributed.all_reduce(
+                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
+                        )
+
+                    with torch.no_grad():
+                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
+                        )
+                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                        moe.expert_bias.add_(expert_bias_delta)
+                        moe.tokens_per_expert.zero_()
+
+    optimizers.register_step_pre_hook(
+        lambda *args, **kwargs: _update_expert_bias(
+            model_parts, parallel_dims=parallel_dims
+        )
+    )
+
+    return optimizers
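The new optimizer pre-hook implements auxiliary-loss-free load balancing: before each optimizer step, experts that received fewer tokens than average get their routing bias nudged up, busier experts get it nudged down, and the per-step token counts are reset. A minimal standalone sketch of just that update rule (toy counts, no distributed all-reduce; the 1e-3 coefficient is an illustrative value):

```python
import torch

# toy per-step routing counts for 4 experts; in torchtitan these live in moe.tokens_per_expert
tokens_per_expert = torch.tensor([10.0, 2.0, 5.0, 3.0])
expert_bias = torch.zeros(4)
load_balance_coeff = 1e-3  # illustrative value

with torch.no_grad():
    # underloaded experts (count < mean) get +coeff, overloaded ones get -coeff
    delta = load_balance_coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
    delta = delta - delta.mean()   # keep the bias shift zero-centered
    expert_bias.add_(delta)
    tokens_per_expert.zero_()      # counts are accumulated fresh for the next step

print(expert_bias)  # ~[-0.00125, 0.00075, -0.00025, 0.00075]
```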

torchtitan/components/quantization/float8.py

Lines changed: 1 addition & 3 deletions
@@ -10,9 +10,7 @@
 
 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.experiments.llama4.infra.expert_parallel import (
-    set_token_group_alignment_size_m,
-)
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,

torchtitan/components/quantization/mx.py

Lines changed: 2 additions & 5 deletions
@@ -13,6 +13,7 @@
 
 from torchtitan.config.job_config import JobConfig, MX
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
         if job_config.mx.moe_fqns_prototype:
-            from torchtitan.experiments.llama4.infra.expert_parallel import (
-                set_token_group_alignment_size,
-            )
-
             mxfp8_block_size = 32
-            set_token_group_alignment_size(mxfp8_block_size)
+            set_token_group_alignment_size_m(mxfp8_block_size)
             logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
 
         # Configure MXFP8
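Both quantization converters now pull `set_token_group_alignment_size_m` from `torchtitan.distributed.expert_parallel`. The constraint it configures is the one stated in the comment above: each expert's token group must be a multiple of the alignment value (valid sizes are 8, 16, and 32; mxfp8 uses 32). A small sketch of the padding arithmetic this implies (`round_up` is an illustrative helper, not a repo function):

```python
def round_up(num_tokens: int, align: int) -> int:
    # pad a token-group length up to the next multiple of `align`
    return (num_tokens + align - 1) // align * align

assert round_up(45, 32) == 64   # mxfp8: groups must be multiples of 32
assert round_up(45, 16) == 48   # smaller alignment, smaller padding
assert round_up(64, 32) == 64   # already aligned, unchanged
```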

torchtitan/experiments/llama4/infra/expert_parallel.py renamed to torchtitan/distributed/expert_parallel.py

Lines changed: 36 additions & 1 deletion
@@ -11,7 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -24,6 +24,41 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
+# from torch.distributed._functional_collectives import all_to_all_single_autograd
+# TODO: there is memory leak issue with AC + all_to_all_single_autograd
+# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
+class _A2A(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, out_splits, in_splits, group):
+        if isinstance(out_splits, torch.Tensor):
+            out_splits = out_splits.tolist()
+        if isinstance(in_splits, torch.Tensor):
+            in_splits = in_splits.tolist()
+        T_out = int(sum(out_splits))
+
+        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
+        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+
+        ctx.in_splits = in_splits
+        ctx.out_splits = out_splits
+        ctx.group = group
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_y):
+        # grad wrt input has length sum(in_splits)
+        T_in = int(sum(ctx.in_splits))
+        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+        dist.all_to_all_single(
+            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
+        )
+        return grad_x, None, None, None
+
+
+def all_to_all_single_autograd(x, out_splits, in_splits, group):
+    return _A2A.apply(x, out_splits, in_splits, group)
+
+
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
 
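The hand-rolled `_A2A` autograd function is a drop-in replacement for the functional-collectives `all_to_all_single_autograd`, kept while the memory-leak issue with activation checkpointing is open. It dispatches a variable number of tokens per rank and routes gradients back along the reversed splits. A minimal sketch of calling it, assuming two ranks launched with `torchrun --nproc_per_node=2` on NCCL-capable GPUs and that the helper is importable from the new module path (split sizes and feature dim are illustrative):

```python
# a2a_demo.py -- run with: torchrun --nproc_per_node=2 a2a_demo.py
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# send matrix: rank 0 keeps 3 tokens and ships 2 to rank 1; rank 1 ships 1 and keeps 4
in_splits = [[3, 2], [1, 4]][rank]
out_splits = [[3, 1], [2, 4]][rank]  # transpose of the send matrix
x = torch.randn(sum(in_splits), 8, device="cuda", requires_grad=True)

y = all_to_all_single_autograd(x, out_splits, in_splits, dist.group.WORLD)
assert y.shape[0] == sum(out_splits)  # output length is determined by out_splits

y.sum().backward()  # gradients are exchanged back via the reversed splits
assert x.grad.shape == x.shape

dist.destroy_process_group()
```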

File renamed without changes.

torchtitan/distributed/utils.py

Lines changed: 4 additions & 2 deletions
@@ -16,11 +16,9 @@
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
-from torch.nn.attention import SDPBackend
 
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.models.attention import ScaledDotProductAttention
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
@@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None):
             )
 
         if cp_context is not None:
+            from torch.nn.attention import SDPBackend
+
+            from torchtitan.models.attention import ScaledDotProductAttention
+
             if SDPBackend.MATH in ScaledDotProductAttention.backends:
                 ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
             assert (

torchtitan/experiments/llama4/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -6,15 +6,16 @@
 
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.models.llama3 import pipeline_llama
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
 from .infra.parallelize import parallelize_llama
 from .model.args import TransformerModelArgs
 from .model.model import Transformer
-from .optimizer import build_llama4_optimizers
 
 __all__ = [
     "TransformerModelArgs",
@@ -40,7 +41,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
     ),
     "17bx128e": TransformerModelArgs(
@@ -51,7 +52,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
     ),
     "debugmodel_irope": TransformerModelArgs(
         dim=256,
@@ -73,7 +74,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
         every_n_layers_nope=4,
         use_flex_attn=True,
@@ -87,7 +88,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
         every_n_layers_nope=4,
         use_flex_attn=True,
         attn_mask_type="block_causal",
@@ -102,7 +103,7 @@
         model_args=llama4_configs,
         parallelize_fn=parallelize_llama,
         pipelining_fn=pipeline_llama,
-        build_optimizers_fn=build_llama4_optimizers,
+        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -21,16 +21,16 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
-from torchtitan.tools.logging import logger
-
-from .expert_parallel import (
+from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
     NoParallel,
     TensorParallel,
 )
 
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.tools.logging import logger
+
 
 def parallelize_llama(
     model: nn.Module,

torchtitan/experiments/llama4/model/args.py

Lines changed: 8 additions & 13 deletions
@@ -5,11 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 from torch import nn
 
 from torchtitan.config import JobConfig
+
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols import BaseModelArgs
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
@@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
-    eos_id: int = 0
     # iRoPE settings
     # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
    # used every n layers. Other layers uses RoPE (rotary positional embedding) and
@@ -45,17 +46,11 @@
     every_n_layers_nope: int | None = None
     fixed_attn_block_size: int = 8192
 
-    # MoE args
-    moe_enabled: bool = True
-    num_experts: int = 8
-    use_shared_expert: bool = True
+    # MoE
+    moe_args: MoEArgs = field(default_factory=MoEArgs)
     auto_scale_hidden_dim: bool = True
     # frequency of using MoE layer instead of feedforward layer in a transformer block
     interleave_moe_layer_step: int = 2
-    # token-choice
-    top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
-    load_balance_coeff: float | None = 1e-3
 
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
@@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len
 
-        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+        if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
             )
-            self.use_grouped_mm = False
+            self.moe_args.use_grouped_mm = False
 
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise NotImplementedError(
@@ -112,7 +107,7 @@ def get_nparams_and_flops(
         nparams_sparse_active = (
             nparams_moe_router
             + nparams_shared_expert
-            + nparams_experts * self.top_k // self.num_experts
+            + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
 
         logger.info(
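One detail worth noting in this dataclass change: the nested `MoEArgs` default has to be supplied via `field(default_factory=MoEArgs)`, because dataclasses reject unhashable class instances as plain defaults, and a shared default instance would leak state across configs anyway. A self-contained illustration with stand-in classes (names here are generic, not from the repo):

```python
from dataclasses import dataclass, field


@dataclass
class MoEArgsDemo:          # stand-in for torchtitan.models.moe.MoEArgs
    num_experts: int = 8
    top_k: int = 1


@dataclass
class ModelArgsDemo:        # stand-in for TransformerModelArgs
    dim: int = 256
    # a plain `moe_args: MoEArgsDemo = MoEArgsDemo()` raises
    # "mutable default ... is not allowed"; default_factory builds a fresh
    # instance per ModelArgsDemo, so configs never share one MoEArgsDemo object
    moe_args: MoEArgsDemo = field(default_factory=MoEArgsDemo)


a, b = ModelArgsDemo(), ModelArgsDemo(moe_args=MoEArgsDemo(num_experts=16))
assert a.moe_args.num_experts == 8 and b.moe_args.num_experts == 16
```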

torchtitan/experiments/llama4/model/model.py

Lines changed: 19 additions & 3 deletions
@@ -10,10 +10,10 @@
 from torch import nn
 
 from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.moe import MoE
 from torchtitan.protocols import ModelProtocol
 
 from .args import TransformerModelArgs
-from .moe import MoE
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -296,12 +296,28 @@ def __init__(
         self.attention = Attention(model_args, attn_use_rope, fixed_attn_block_size)
 
         # use MoE layer for every interleave_moe_layer_step FFN layers
+        moe_args = model_args.moe_args
         self.moe_enabled = (
-            model_args.moe_enabled
+            moe_args.moe_enabled
             and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
         )
         if self.moe_enabled:
-            self.moe = MoE(model_args)
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            ffn_dim_multiplier = model_args.ffn_dim_multiplier
+            hidden_dim = int(2 * hidden_dim / 3)
+            if ffn_dim_multiplier is not None:
+                hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = moe_args.top_k + moe_args.num_shared_experts
+
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            self.moe = MoE(moe_args, dim=dim, hidden_dim=hidden_dim)
         else:
             self.feed_forward = FeedForward(
                 dim=model_args.dim,
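With MoE construction unified, the transformer block now derives the expert FFN width itself: start from the dense FFN sizing (2/3 of 4*dim, scaled by ffn_dim_multiplier), optionally divide by the number of active experts per token (top_k plus shared experts) so activated FFN compute stays roughly constant, then round up to a multiple of multiple_of. A worked sketch of that arithmetic with illustrative numbers (not a published Llama 4 config):

```python
def moe_hidden_dim(dim, ffn_dim_multiplier, multiple_of, top_k, num_shared_experts,
                   auto_scale_hidden_dim=True):
    # mirrors the arithmetic in TransformerBlock.__init__ above
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    if auto_scale_hidden_dim:
        hidden_dim = int(hidden_dim / (top_k + num_shared_experts))
    hidden_dim += -hidden_dim % multiple_of  # round up to a multiple_of boundary
    return hidden_dim

# 5120 -> 20480 -> 13653 -> *1.2 = 16383 -> /(1 + 1) = 8191 -> rounded up -> 8192
print(moe_hidden_dim(5120, 1.2, 2048, top_k=1, num_shared_experts=1))  # 8192
```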
