53 changes: 53 additions & 0 deletions torchtitan/components/optimizer.py
@@ -24,6 +24,7 @@
__all__ = [
"OptimizersContainer",
"build_optimizers",
"build_optimizers_with_moe_load_balancing",
]


@@ -323,3 +324,55 @@ def build_optimizers(
)

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)


def build_optimizers_with_moe_load_balancing(
model_parts: list[nn.Module],
optimizer_config: OptimizerConfig,
parallel_dims: ParallelDims,
ft_manager: FTManager | None = None,
) -> OptimizersContainer:
optimizers = build_optimizers(
model_parts=model_parts,
optimizer_config=optimizer_config,
parallel_dims=parallel_dims,
ft_manager=ft_manager,
)

# for MoE auxiliary-loss-free load balancing
def _update_expert_bias(
model_parts: list[nn.Module],
parallel_dims: ParallelDims,
):
dp_cp_mesh = (
parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
)
# TODO: Currently this sync is blocking (thus exposed) and happens on the
# default compute stream. Need to assess if this is OK performance-wise.
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if transformer_block.moe_enabled:
moe = transformer_block.moe
if moe.load_balance_coeff is None:
return

if dp_cp_mesh is not None:
torch.distributed.all_reduce(
moe.tokens_per_expert, group=dp_cp_mesh.get_group()
)

with torch.no_grad():
expert_bias_delta = moe.load_balance_coeff * torch.sign(
moe.tokens_per_expert.mean() - moe.tokens_per_expert
)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
moe.expert_bias.add_(expert_bias_delta)
moe.tokens_per_expert.zero_()

optimizers.register_step_pre_hook(
lambda *args, **kwargs: _update_expert_bias(
model_parts, parallel_dims=parallel_dims
)
)

return optimizers
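
Note: the registered step pre-hook runs this update once per optimizer step. Below is a toy, single-device sketch of the bias-update rule used in _update_expert_bias above, with hypothetical per-expert token counts and no distributed all-reduce:

import torch

load_balance_coeff = 1e-3
tokens_per_expert = torch.tensor([10.0, 2.0, 4.0, 8.0])  # hypothetical per-expert counts

# Same rule as _update_expert_bias: push the bias down for over-loaded experts, up for under-loaded ones.
expert_bias_delta = load_balance_coeff * torch.sign(
    tokens_per_expert.mean() - tokens_per_expert
)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()  # keep the bias zero-mean
# Experts that saw 10 and 8 tokens get a negative delta; those that saw 2 and 4 get a positive one.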
4 changes: 1 addition & 3 deletions torchtitan/components/quantization/float8.py
@@ -10,9 +10,7 @@

from torchtitan.config.job_config import Float8, JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.llama4.infra.expert_parallel import (
set_token_group_alignment_size_m,
)
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
from torchtitan.protocols.model_converter import (
ModelConverter,
register_model_converter,
7 changes: 2 additions & 5 deletions torchtitan/components/quantization/mx.py
@@ -13,6 +13,7 @@

from torchtitan.config.job_config import JobConfig, MX
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
from torchtitan.protocols.model_converter import (
ModelConverter,
register_model_converter,
@@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

# For MoE training with mxfp8, token group sizes must be multiples of 32
if job_config.mx.moe_fqns_prototype:
from torchtitan.experiments.llama4.infra.expert_parallel import (
set_token_group_alignment_size,
)

mxfp8_block_size = 32
set_token_group_alignment_size(mxfp8_block_size)
set_token_group_alignment_size_m(mxfp8_block_size)
logger.info(f"Setting token group alignment size to {mxfp8_block_size}")

# Configure MXFP8
torchtitan/distributed/expert_parallel.py
@@ -11,7 +11,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._functional_collectives import all_to_all_single_autograd
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
@@ -24,6 +23,41 @@
from torch.distributed.tensor.placement_types import Placement


# from torch.distributed._functional_collectives import all_to_all_single_autograd
# TODO: there is memory leak issue with AC + all_to_all_single_autograd
# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
class _A2A(torch.autograd.Function):
@staticmethod
def forward(ctx, x, out_splits, in_splits, group):
if isinstance(out_splits, torch.Tensor):
out_splits = out_splits.tolist()
Inline review comment (Contributor): won't tolist() cause a d2h sync? Is this okay / intentional in this case?

Reply (tianyu-l, Contributor Author, Aug 6, 2025): It will. This is a temporary fix, but currently in EP there are multiple places with d2h sync. I'm working on another implementation to kill them.

if isinstance(in_splits, torch.Tensor):
in_splits = in_splits.tolist()
T_out = int(sum(out_splits))

y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)

ctx.in_splits = in_splits
ctx.out_splits = out_splits
ctx.group = group
return y

@staticmethod
def backward(ctx, grad_y):
# grad wrt input has length sum(in_splits)
T_in = int(sum(ctx.in_splits))
grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
dist.all_to_all_single(
grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
)
return grad_x, None, None, None


def all_to_all_single_autograd(x, out_splits, in_splits, group):
return _A2A.apply(x, out_splits, in_splits, group)


TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]

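
For reference, a minimal usage sketch of the all_to_all_single_autograd wrapper defined above. This assumes a 2-rank NCCL setup launched with torchrun; the split sizes, tensor shapes, and script name are made up for illustration:

# torchrun --nproc_per_node=2 a2a_demo.py  (assumes 2 CUDA devices and the NCCL backend)
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # in_splits[i]: rows this rank sends to rank i; out_splits[i]: rows it receives from rank i.
    in_splits = [1, 2] if rank == 0 else [3, 1]
    out_splits = [1, 3] if rank == 0 else [2, 1]
    x = torch.randn(sum(in_splits), 4, device="cuda", requires_grad=True)

    y = all_to_all_single_autograd(x, out_splits, in_splits, dist.group.WORLD)
    assert y.shape[0] == sum(out_splits)

    # The backward pass is a reverse all-to-all with the splits swapped.
    y.sum().backward()
    assert x.grad.shape == x.shape

    dist.destroy_process_group()


if __name__ == "__main__":
    main()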
5 changes: 0 additions & 5 deletions torchtitan/distributed/parallel_dims.py
@@ -232,8 +232,3 @@ def seq_len_divisor(self):
# when load balancing is enabled (by default).
# https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246
return self.tp * (self.cp * 2)

@cached_property
def dense_params_mesh_ndim(self):
# Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh
return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
20 changes: 10 additions & 10 deletions torchtitan/distributed/utils.py
@@ -16,11 +16,9 @@
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.nn.attention import SDPBackend

from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.models.attention import ScaledDotProductAttention
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type

@@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None):
)

if cp_context is not None:
from torch.nn.attention import SDPBackend

from torchtitan.models.attention import ScaledDotProductAttention

if SDPBackend.MATH in ScaledDotProductAttention.backends:
ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
assert (
@@ -319,7 +321,7 @@ def clip_grad_norm_(
error_if_nonfinite: bool = False,
foreach: bool | None = None,
pp_mesh: DeviceMesh | None = None,
ep_dense_params_mesh_ndim: int | None = None,
ep_enabled: bool = False,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
@@ -349,15 +351,14 @@
Total norm of the parameter gradients (viewed as a single vector).

"""
if ep_dense_params_mesh_ndim is not None:
if ep_enabled:
return _clip_grad_norm_with_ep(
parameters,
max_norm,
norm_type,
error_if_nonfinite,
foreach,
pp_mesh,
ep_dense_params_mesh_ndim,
)

if isinstance(parameters, torch.Tensor):
@@ -401,7 +402,6 @@ def _clip_grad_norm_with_ep(
error_if_nonfinite: bool,
foreach: bool | None,
pp_mesh: DeviceMesh | None,
dense_params_mesh_ndim: int,
) -> torch.Tensor:
ep_params = []
non_ep_params = []
@@ -412,12 +412,12 @@
if p.grad is None:
continue
assert isinstance(p, DTensor) and isinstance(p.grad, DTensor)
if p.device_mesh.ndim == dense_params_mesh_ndim:
non_ep_params.append(p)
non_ep_grads.append(p.grad)
else:
if "ep" in p.device_mesh.mesh_dim_names:
ep_params.append(p)
ep_grads.append(p.grad)
else:
non_ep_params.append(p)
non_ep_grads.append(p.grad)
ep_grads_total_norm = torch.nn.utils.get_total_norm(
ep_grads, norm_type, error_if_nonfinite, foreach
).full_tensor()
6 changes: 1 addition & 5 deletions torchtitan/experiments/forge/example_train.py
@@ -231,11 +231,7 @@ def train_step(
pp_mesh=(
parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
),
ep_dense_params_mesh_ndim=(
parallel_dims.dense_params_mesh_ndim
if parallel_dims.ep_enabled
else None
),
ep_enabled=parallel_dims.ep_enabled,
)
self.checkpointer.maybe_wait_for_staging()
self.optimizers.step()
13 changes: 7 additions & 6 deletions torchtitan/experiments/llama4/__init__.py
@@ -6,15 +6,16 @@

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.models.moe import MoEArgs
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_llama
from .model.args import TransformerModelArgs
from .model.model import Transformer
from .optimizer import build_llama4_optimizers

__all__ = [
"TransformerModelArgs",
@@ -40,7 +41,7 @@
multiple_of=2048,
rope_theta=500000,
max_seq_len=10485760,
num_experts=16,
moe_args=MoEArgs(num_experts=16),
interleave_moe_layer_step=1,
),
"17bx128e": TransformerModelArgs(
@@ -51,7 +52,7 @@
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
num_experts=128,
moe_args=MoEArgs(num_experts=128),
),
"debugmodel_irope": TransformerModelArgs(
dim=256,
@@ -73,7 +74,7 @@
multiple_of=2048,
rope_theta=500000,
max_seq_len=10485760,
num_experts=16,
moe_args=MoEArgs(num_experts=16),
interleave_moe_layer_step=1,
every_n_layers_nope=4,
use_flex_attn=True,
@@ -87,7 +88,7 @@
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
num_experts=128,
moe_args=MoEArgs(num_experts=128),
every_n_layers_nope=4,
use_flex_attn=True,
attn_mask_type="block_causal",
@@ -102,7 +103,7 @@
model_args=llama4_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_llama4_optimizers,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
8 changes: 4 additions & 4 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -21,16 +21,16 @@
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.tools.logging import logger

from .expert_parallel import (
from torchtitan.distributed.expert_parallel import (
ExpertParallel,
ExpertTensorParallel,
NoParallel,
TensorParallel,
)

from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.tools.logging import logger


def parallelize_llama(
model: nn.Module,
21 changes: 8 additions & 13 deletions torchtitan/experiments/llama4/model/args.py
@@ -5,11 +5,13 @@
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from dataclasses import dataclass, field

from torch import nn

from torchtitan.config import JobConfig

from torchtitan.models.moe import MoEArgs
from torchtitan.protocols import BaseModelArgs
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability
@@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs):

use_flex_attn: bool = False
attn_mask_type: str = "causal"
eos_id: int = 0
# iRoPE settings
# When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
# used every n layers. Other layers uses RoPE (rotary positional embedding) and
@@ -45,17 +46,11 @@
every_n_layers_nope: int | None = None
fixed_attn_block_size: int = 8192

# MoE args
moe_enabled: bool = True
num_experts: int = 8
use_shared_expert: bool = True
# MoE
moe_args: MoEArgs = field(default_factory=MoEArgs)
auto_scale_hidden_dim: bool = True
# frequency of using MoE layer instead of feedforward layer in a transformer block
interleave_moe_layer_step: int = 2
# token-choice
top_k: int = 1
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
@@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

if self.use_grouped_mm and not has_cuda_capability(9, 0):
if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
)
self.use_grouped_mm = False
self.moe_args.use_grouped_mm = False

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise NotImplementedError(
@@ -112,7 +107,7 @@ def get_nparams_and_flops(
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_experts * self.top_k // self.num_experts
+ nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
)

logger.info(
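
With this refactor the MoE hyperparameters live on MoEArgs rather than directly on TransformerModelArgs. A small configuration sketch follows, using only field names that appear in this diff and assuming the remaining TransformerModelArgs fields keep their defaults:

from torchtitan.models.moe import MoEArgs
from torchtitan.experiments.llama4.model.args import TransformerModelArgs

# num_experts, top_k, and use_grouped_mm are the MoEArgs fields visible in this diff;
# other MoE knobs (e.g. load-balance settings) are omitted here.
moe_args = MoEArgs(num_experts=16, top_k=1, use_grouped_mm=True)
model_args = TransformerModelArgs(
    moe_args=moe_args,
    interleave_moe_layer_step=1,  # an MoE layer in every block, as in the 17bx16e config
)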