Commit 7f6b148

unify moe implementation for llama4 and deepseek_v3
1 parent a204e31 commit 7f6b148

24 files changed: +315 -584 lines changed

torchtitan/components/optimizer.py

Lines changed: 53 additions & 0 deletions
@@ -24,6 +24,7 @@
 __all__ = [
     "OptimizersContainer",
     "build_optimizers",
+    "build_optimizers_with_moe_load_balancing",
 ]


@@ -323,3 +324,55 @@ def build_optimizers(
     )

     return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
+
+
+def build_optimizers_with_moe_load_balancing(
+    model_parts: list[nn.Module],
+    optimizer_config: OptimizerConfig,
+    parallel_dims: ParallelDims,
+    ft_manager: FTManager | None = None,
+) -> OptimizersContainer:
+    optimizers = build_optimizers(
+        model_parts=model_parts,
+        optimizer_config=optimizer_config,
+        parallel_dims=parallel_dims,
+        ft_manager=ft_manager,
+    )
+
+    # for MoE auxiliary-loss-free load balancing
+    def _update_expert_bias(
+        model_parts: list[nn.Module],
+        parallel_dims: ParallelDims,
+    ):
+        dp_cp_mesh = (
+            parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
+        )
+        # TODO: Currently this sync is blocking (thus exposed) and happens on the
+        # default compute stream. Need to assess if this is OK performance-wise.
+        for model_part in model_parts:
+            for transformer_block in model_part.layers.values():
+                if transformer_block.moe_enabled:
+                    moe = transformer_block.moe
+                    if moe.load_balance_coeff is None:
+                        return
+
+                    if dp_cp_mesh is not None:
+                        torch.distributed.all_reduce(
+                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
+                        )
+
+                    with torch.no_grad():
+                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
+                        )
+                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                        moe.expert_bias.add_(expert_bias_delta)
+                        moe.tokens_per_expert.zero_()
+
+    optimizers.register_step_pre_hook(
+        lambda *args, **kwargs: _update_expert_bias(
+            model_parts, parallel_dims=parallel_dims
+        )
+    )
+
+    return optimizers
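Note: the step pre-hook registered above performs auxiliary-loss-free load balancing: experts that received more tokens than average get their bias nudged down, under-loaded experts get nudged up, and the per-step counters are reset. A minimal sketch of just that update on plain tensors (the names mirror the diff; the values and the lack of a model/all_reduce are purely illustrative):

import torch

load_balance_coeff = 1e-3
tokens_per_expert = torch.tensor([120.0, 80.0, 100.0, 100.0])  # routed-token counts for one step
expert_bias = torch.zeros(4)

with torch.no_grad():
    # sign(mean - count): over-loaded experts get -coeff, under-loaded get +coeff
    expert_bias_delta = load_balance_coeff * torch.sign(
        tokens_per_expert.mean() - tokens_per_expert
    )
    # keep the bias zero-mean across experts
    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
    expert_bias.add_(expert_bias_delta)
    tokens_per_expert.zero_()

print(expert_bias)  # tensor([-0.0010,  0.0010,  0.0000,  0.0000])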

torchtitan/components/quantization/float8.py

Lines changed: 1 addition & 3 deletions
@@ -10,9 +10,7 @@

 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.experiments.llama4.infra.expert_parallel import (
-    set_token_group_alignment_size_m,
-)
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,

torchtitan/components/quantization/mx.py

Lines changed: 2 additions & 5 deletions
@@ -13,6 +13,7 @@

 from torchtitan.config.job_config import JobConfig, MX
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

         # For MoE training with mxfp8, token group sizes must be multiples of 32
         if job_config.mx.moe_fqns_prototype:
-            from torchtitan.experiments.llama4.infra.expert_parallel import (
-                set_token_group_alignment_size,
-            )
-
             mxfp8_block_size = 32
-            set_token_group_alignment_size(mxfp8_block_size)
+            set_token_group_alignment_size_m(mxfp8_block_size)
             logger.info(f"Setting token group alignment size to {mxfp8_block_size}")

         # Configure MXFP8
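Note: the alignment requirement here means each expert's token group is padded up to a multiple of the mxfp8 block size (32) before the grouped GEMM. A minimal sketch of the rounding arithmetic; pad_to_multiple is a hypothetical helper for illustration, not a torchtitan API:

def pad_to_multiple(n: int, align: int = 32) -> int:
    # round a token-group length up to the next multiple of `align`
    return ((n + align - 1) // align) * align

assert pad_to_multiple(1) == 32
assert pad_to_multiple(32) == 32
assert pad_to_multiple(33) == 64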

torchtitan/experiments/llama4/infra/expert_parallel.py renamed to torchtitan/distributed/expert_parallel.py

Lines changed: 35 additions & 1 deletion
@@ -11,7 +11,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -24,6 +23,41 @@
 from torch.distributed.tensor.placement_types import Placement


+# from torch.distributed._functional_collectives import all_to_all_single_autograd
+# TODO: there is memory leak issue with AC + all_to_all_single_autograd
+# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
+class _A2A(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, out_splits, in_splits, group):
+        if isinstance(out_splits, torch.Tensor):
+            out_splits = out_splits.tolist()
+        if isinstance(in_splits, torch.Tensor):
+            in_splits = in_splits.tolist()
+        T_out = int(sum(out_splits))
+
+        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
+        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+
+        ctx.in_splits = in_splits
+        ctx.out_splits = out_splits
+        ctx.group = group
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_y):
+        # grad wrt input has length sum(in_splits)
+        T_in = int(sum(ctx.in_splits))
+        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+        dist.all_to_all_single(
+            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
+        )
+        return grad_x, None, None, None
+
+
+def all_to_all_single_autograd(x, out_splits, in_splits, group):
+    return _A2A.apply(x, out_splits, in_splits, group)
+
+
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
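Note: the custom _A2A function above wraps dist.all_to_all_single and simply swaps the two split lists in backward, since the gradient of a token shuffle is the reverse shuffle. For context, a sketch of how such splits can be derived from router output on one rank (names and values are illustrative only, not torchtitan code, and no process group is created here):

import torch

num_experts, ep_degree = 4, 2                      # 2 experts per EP rank
chosen_expert = torch.tensor([0, 3, 1, 1, 2, 0])   # routed expert per local token

tokens_per_expert = torch.bincount(chosen_expert, minlength=num_experts)  # tensor([2, 2, 1, 1])
# tokens headed to each EP rank = sum over that rank's experts
input_splits = tokens_per_expert.view(ep_degree, -1).sum(dim=1).tolist()  # [4, 2]

# output_splits (how many tokens arrive from each peer) would normally be obtained by
# exchanging these counts across ranks before calling all_to_all_single_autograd.
print(tokens_per_expert.tolist(), input_splits)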

torchtitan/distributed/parallel_dims.py

Lines changed: 0 additions & 5 deletions
@@ -232,8 +232,3 @@ def seq_len_divisor(self):
         # when load balancing is enabled (by default).
         # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246
         return self.tp * (self.cp * 2)
-
-    @cached_property
-    def dense_params_mesh_ndim(self):
-        # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh
-        return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
File renamed without changes.

torchtitan/distributed/utils.py

Lines changed: 10 additions & 10 deletions
@@ -16,11 +16,9 @@
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
-from torch.nn.attention import SDPBackend

 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.models.attention import ScaledDotProductAttention
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type

@@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None):
                )

            if cp_context is not None:
+                from torch.nn.attention import SDPBackend
+
+                from torchtitan.models.attention import ScaledDotProductAttention
+
                if SDPBackend.MATH in ScaledDotProductAttention.backends:
                    ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
                assert (
@@ -319,7 +321,7 @@ def clip_grad_norm_(
     error_if_nonfinite: bool = False,
     foreach: bool | None = None,
     pp_mesh: DeviceMesh | None = None,
-    ep_dense_params_mesh_ndim: int | None = None,
+    ep_enabled: bool = False,
 ) -> torch.Tensor:
     """
     Clip the gradient norm of an iterable of parameters.
@@ -349,15 +351,14 @@ def clip_grad_norm_(
         Total norm of the parameter gradients (viewed as a single vector).

    """
-    if ep_dense_params_mesh_ndim is not None:
+    if ep_enabled:
         return _clip_grad_norm_with_ep(
             parameters,
             max_norm,
             norm_type,
             error_if_nonfinite,
             foreach,
             pp_mesh,
-            ep_dense_params_mesh_ndim,
         )

     if isinstance(parameters, torch.Tensor):
@@ -401,7 +402,6 @@ def _clip_grad_norm_with_ep(
     error_if_nonfinite: bool,
     foreach: bool | None,
     pp_mesh: DeviceMesh | None,
-    dense_params_mesh_ndim: int,
 ) -> torch.Tensor:
     ep_params = []
     non_ep_params = []
@@ -412,12 +412,12 @@ def _clip_grad_norm_with_ep(
         if p.grad is None:
             continue
         assert isinstance(p, DTensor) and isinstance(p.grad, DTensor)
-        if p.device_mesh.ndim == dense_params_mesh_ndim:
-            non_ep_params.append(p)
-            non_ep_grads.append(p.grad)
-        else:
+        if "ep" in p.device_mesh.mesh_dim_names:
             ep_params.append(p)
             ep_grads.append(p.grad)
+        else:
+            non_ep_params.append(p)
+            non_ep_grads.append(p.grad)
     ep_grads_total_norm = torch.nn.utils.get_total_norm(
         ep_grads, norm_type, error_if_nonfinite, foreach
     ).full_tensor()
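Note: with dp2ep expert parallelism, expert weights are now identified by the 'ep' mesh dimension name rather than by mesh ndim, and their gradient norm is reduced separately before being combined with the dense-parameter norm. A sketch of combining the two group norms into one global p-norm on plain tensors (stand-in gradients, no DTensor or meshes; it uses the same torch.nn.utils.get_total_norm helper the code above relies on):

import torch

norm_type = 2.0
ep_grads = [torch.randn(8), torch.randn(16)]   # stand-ins for expert-parallel grads
non_ep_grads = [torch.randn(32)]               # stand-ins for dense grads

ep_norm = torch.nn.utils.get_total_norm(ep_grads, norm_type)
non_ep_norm = torch.nn.utils.get_total_norm(non_ep_grads, norm_type)

# ||g|| = (||g_ep||^p + ||g_dense||^p)^(1/p) for a finite p-norm
total_norm = torch.linalg.vector_norm(torch.stack([ep_norm, non_ep_norm]), ord=norm_type)
print(total_norm)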

torchtitan/experiments/forge/example_train.py

Lines changed: 1 addition & 5 deletions
@@ -231,11 +231,7 @@ def train_step(
             pp_mesh=(
                 parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
             ),
-            ep_dense_params_mesh_ndim=(
-                parallel_dims.dense_params_mesh_ndim
-                if parallel_dims.ep_enabled
-                else None
-            ),
+            ep_enabled=parallel_dims.ep_enabled,
         )
         self.checkpointer.maybe_wait_for_staging()
         self.optimizers.step()

torchtitan/experiments/llama4/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -6,15 +6,16 @@

 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.models.llama3 import pipeline_llama
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_llama
 from .model.args import TransformerModelArgs
 from .model.model import Transformer
-from .optimizer import build_llama4_optimizers

 __all__ = [
     "TransformerModelArgs",
@@ -40,7 +41,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
     ),
     "17bx128e": TransformerModelArgs(
@@ -51,7 +52,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
     ),
     "debugmodel_irope": TransformerModelArgs(
         dim=256,
@@ -73,7 +74,7 @@
         multiple_of=2048,
         rope_theta=500000,
         max_seq_len=10485760,
-        num_experts=16,
+        moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
         every_n_layers_nope=4,
         use_flex_attn=True,
@@ -87,7 +88,7 @@
         ffn_dim_multiplier=1.2,
         multiple_of=2048,
         rope_theta=500000,
-        num_experts=128,
+        moe_args=MoEArgs(num_experts=128),
         every_n_layers_nope=4,
         use_flex_attn=True,
         attn_mask_type="block_causal",
@@ -102,7 +103,7 @@
         model_args=llama4_configs,
         parallelize_fn=parallelize_llama,
         pipelining_fn=pipeline_llama,
-        build_optimizers_fn=build_llama4_optimizers,
+        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
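Note: the config flavors now pass MoE hyperparameters through a single MoEArgs value from torchtitan.models.moe rather than a flat num_experts field. A minimal before/after sketch (only num_experts is shown, since no other MoEArgs field appears in this diff):

from torchtitan.models.moe import MoEArgs

# before: TransformerModelArgs(..., num_experts=16, interleave_moe_layer_step=1)
# after:  TransformerModelArgs(..., moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1)
moe_args = MoEArgs(num_experts=16)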

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -21,16 +21,16 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims

-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
-from torchtitan.tools.logging import logger
-
-from .expert_parallel import (
+from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
     NoParallel,
     TensorParallel,
 )

+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.tools.logging import logger
+

 def parallelize_llama(
     model: nn.Module,
