diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2a112177e0..ce71ac7f0c 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -24,6 +24,7 @@ __all__ = [ "OptimizersContainer", "build_optimizers", + "build_optimizers_with_moe_load_balancing", ] @@ -323,3 +324,55 @@ def build_optimizers( ) return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + +def build_optimizers_with_moe_load_balancing( + model_parts: list[nn.Module], + optimizer_config: OptimizerConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager | None = None, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + optimizer_config=optimizer_config, + parallel_dims=parallel_dims, + ft_manager=ft_manager, + ) + + # for MoE auxiliary-loss-free load balancing + def _update_expert_bias( + model_parts: list[nn.Module], + parallel_dims: ParallelDims, + ): + dp_cp_mesh = ( + parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + ) + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. + for model_part in model_parts: + for transformer_block in model_part.layers.values(): + if transformer_block.moe_enabled: + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 58699b92ee..3629258154 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -10,9 +10,7 @@ from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.experiments.llama4.infra.expert_parallel import ( - set_token_group_alignment_size_m, -) +from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 276208c9a8..15c74b7fd7 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -13,6 +13,7 @@ from torchtitan.config.job_config import JobConfig, MX from torchtitan.distributed import ParallelDims +from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, @@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): # For MoE training with mxfp8, token group sizes must be multiples of 32 if job_config.mx.moe_fqns_prototype: - from torchtitan.experiments.llama4.infra.expert_parallel import ( - set_token_group_alignment_size, - ) - mxfp8_block_size = 32 - set_token_group_alignment_size(mxfp8_block_size) + 
set_token_group_alignment_size_m(mxfp8_block_size) logger.info(f"Setting token group alignment size to {mxfp8_block_size}") # Configure MXFP8 diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/distributed/expert_parallel.py similarity index 90% rename from torchtitan/experiments/llama4/infra/expert_parallel.py rename to torchtitan/distributed/expert_parallel.py index 9a9dad66ae..bc5d43f9f2 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -11,7 +11,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -24,6 +23,41 @@ from torch.distributed.tensor.placement_types import Placement +# from torch.distributed._functional_collectives import all_to_all_single_autograd +# TODO: there is memory leak issue with AC + all_to_all_single_autograd +# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467 +class _A2A(torch.autograd.Function): + @staticmethod + def forward(ctx, x, out_splits, in_splits, group): + if isinstance(out_splits, torch.Tensor): + out_splits = out_splits.tolist() + if isinstance(in_splits, torch.Tensor): + in_splits = in_splits.tolist() + T_out = int(sum(out_splits)) + + y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits + dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group) + + ctx.in_splits = in_splits + ctx.out_splits = out_splits + ctx.group = group + return y + + @staticmethod + def backward(ctx, grad_y): + # grad wrt input has length sum(in_splits) + T_in = int(sum(ctx.in_splits)) + grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:])) + dist.all_to_all_single( + grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group + ) + return grad_x, None, None, None + + +def all_to_all_single_autograd(x, out_splits, in_splits, group): + return _A2A.apply(x, out_splits, in_splits, group) + + TOKEN_GROUP_ALIGN_SIZE_M = 8 ValidTokenGroupAlignmentSize = Literal[8, 16, 32] diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 01e14cc0b0..3108049a6f 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -232,8 +232,3 @@ def seq_len_divisor(self): # when load balancing is enabled (by default). 
# https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246 return self.tp * (self.cp * 2) - - @cached_property - def dense_params_mesh_ndim(self): - # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh - return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline_parallel.py similarity index 100% rename from torchtitan/distributed/pipeline.py rename to torchtitan/distributed/pipeline_parallel.py diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index aa25149d66..7d4dc935c3 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -16,11 +16,9 @@ from torch import distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor -from torch.nn.attention import SDPBackend from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims -from torchtitan.models.attention import ScaledDotProductAttention from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None): ) if cp_context is not None: + from torch.nn.attention import SDPBackend + + from torchtitan.models.attention import ScaledDotProductAttention + if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) assert ( @@ -319,7 +321,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, - ep_dense_params_mesh_ndim: int | None = None, + ep_enabled: bool = False, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -349,7 +351,7 @@ def clip_grad_norm_( Total norm of the parameter gradients (viewed as a single vector). 
""" - if ep_dense_params_mesh_ndim is not None: + if ep_enabled: return _clip_grad_norm_with_ep( parameters, max_norm, @@ -357,7 +359,6 @@ def clip_grad_norm_( error_if_nonfinite, foreach, pp_mesh, - ep_dense_params_mesh_ndim, ) if isinstance(parameters, torch.Tensor): @@ -401,7 +402,6 @@ def _clip_grad_norm_with_ep( error_if_nonfinite: bool, foreach: bool | None, pp_mesh: DeviceMesh | None, - dense_params_mesh_ndim: int, ) -> torch.Tensor: ep_params = [] non_ep_params = [] @@ -412,12 +412,12 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - if p.device_mesh.ndim == dense_params_mesh_ndim: - non_ep_params.append(p) - non_ep_grads.append(p.grad) - else: + if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) + else: + non_ep_params.append(p) + non_ep_grads.append(p.grad) ep_grads_total_norm = torch.nn.utils.get_total_norm( ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index c54fc645c4..0bebd197d1 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -231,11 +231,7 @@ def train_step( pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), - ep_dense_params_mesh_ndim=( - parallel_dims.dense_params_mesh_ndim - if parallel_dims.ep_enabled - else None - ), + ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 7e3dd8f07c..0ffe139dae 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -6,15 +6,16 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import pipeline_llama +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .infra.parallelize import parallelize_llama from .model.args import TransformerModelArgs from .model.model import Transformer -from .optimizer import build_llama4_optimizers __all__ = [ "TransformerModelArgs", @@ -40,7 +41,7 @@ multiple_of=2048, rope_theta=500000, max_seq_len=10485760, - num_experts=16, + moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, ), "17bx128e": TransformerModelArgs( @@ -51,7 +52,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, - num_experts=128, + moe_args=MoEArgs(num_experts=128), ), "debugmodel_irope": TransformerModelArgs( dim=256, @@ -73,7 +74,7 @@ multiple_of=2048, rope_theta=500000, max_seq_len=10485760, - num_experts=16, + moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, every_n_layers_nope=4, use_flex_attn=True, @@ -87,7 +88,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, - num_experts=128, + moe_args=MoEArgs(num_experts=128), every_n_layers_nope=4, use_flex_attn=True, attn_mask_type="block_causal", @@ -102,7 +103,7 @@ model_args=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_llama4_optimizers, + 
build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index b1e60f9962..4a7a860680 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -21,16 +21,16 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp -from torchtitan.tools.logging import logger - -from .expert_parallel import ( +from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, NoParallel, TensorParallel, ) +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.tools.logging import logger + def parallelize_llama( model: nn.Module, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index 741f00fd4e..dda130548d 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -5,11 +5,13 @@ # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch import nn from torchtitan.config import JobConfig + +from torchtitan.models.moe import MoEArgs from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability @@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs): use_flex_attn: bool = False attn_mask_type: str = "causal" - eos_id: int = 0 # iRoPE settings # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is # used every n layers. 
Other layers uses RoPE (rotary positional embedding) and @@ -45,17 +46,11 @@ class TransformerModelArgs(BaseModelArgs): every_n_layers_nope: int | None = None fixed_attn_block_size: int = 8192 - # MoE args - moe_enabled: bool = True - num_experts: int = 8 - use_shared_expert: bool = True + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) auto_scale_hidden_dim: bool = True # frequency of using MoE layer instead of feedforward layer in a transformer block interleave_moe_layer_step: int = 2 - # token-choice - top_k: int = 1 - use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation - load_balance_coeff: float | None = 1e-3 def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len @@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if self.use_grouped_mm and not has_cuda_capability(9, 0): + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", ) - self.use_grouped_mm = False + self.moe_args.use_grouped_mm = False if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise NotImplementedError( @@ -112,7 +107,7 @@ def get_nparams_and_flops( nparams_sparse_active = ( nparams_moe_router + nparams_shared_expert - + nparams_experts * self.top_k // self.num_experts + + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) logger.info( diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 4e276efbbc..eb46a22b00 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -10,10 +10,10 @@ from torch import nn from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.moe import MoE from torchtitan.protocols import ModelProtocol from .args import TransformerModelArgs -from .moe import MoE def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: @@ -296,12 +296,25 @@ def __init__( self.attention = Attention(model_args, attn_use_rope, fixed_attn_block_size) # use MoE layer for every interleave_moe_layer_step FFN layers - self.moe_enabled = ( - model_args.moe_enabled - and (layer_id + 1) % model_args.interleave_moe_layer_step == 0 - ) + moe_args = model_args.moe_args + self.moe_enabled = (layer_id + 1) % model_args.interleave_moe_layer_step == 0 if self.moe_enabled: - self.moe = MoE(model_args) + dim = model_args.dim + hidden_dim = 4 * model_args.dim + ffn_dim_multiplier = model_args.ffn_dim_multiplier + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + + hidden_dim_denom = 1 + if model_args.auto_scale_hidden_dim: + hidden_dim_denom = moe_args.top_k + moe_args.num_shared_experts + + if model_args.auto_scale_hidden_dim: + hidden_dim = int(hidden_dim / hidden_dim_denom) + hidden_dim += -hidden_dim % model_args.multiple_of + + self.moe = MoE(moe_args, dim=dim, hidden_dim=hidden_dim) else: self.feed_forward = FeedForward( dim=model_args.dim, diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py deleted file mode 100644 index 0986452fae..0000000000 --- a/torchtitan/experiments/llama4/optimizer.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from torchtitan.components.ft import FTManager -from torchtitan.components.optimizer import build_optimizers, OptimizersContainer -from torchtitan.config import Optimizer as OptimizerConfig -from torchtitan.distributed import ParallelDims - - -# for MoE auxiliary-loss-free load balancing -def _update_expert_bias( - model_parts: list[nn.Module], - parallel_dims: ParallelDims, -): - dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None - ) - # TODO: Currently this sync is blocking (thus exposed) and happens on the - # default compute stream. Need to assess if this is OK performance-wise. - for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if transformer_block.moe_enabled: - moe = transformer_block.moe - if moe.load_balance_coeff is None: - return - - if dp_cp_mesh is not None: - torch.distributed.all_reduce( - moe.tokens_per_expert, group=dp_cp_mesh.get_group() - ) - - with torch.no_grad(): - expert_bias_delta = moe.load_balance_coeff * torch.sign( - moe.tokens_per_expert.mean() - moe.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - moe.expert_bias.add_(expert_bias_delta) - moe.tokens_per_expert.zero_() - - -def build_llama4_optimizers( - model_parts: list[nn.Module], - optimizer_config: OptimizerConfig, - parallel_dims: ParallelDims, - ft_manager: FTManager | None = None, -) -> OptimizersContainer: - optimizers = build_optimizers( - model_parts=model_parts, - optimizer_config=optimizer_config, - parallel_dims=parallel_dims, - ft_manager=ft_manager, - ) - - optimizers.register_step_pre_hook( - lambda *args, **kwargs: _update_expert_bias( - model_parts, parallel_dims=parallel_dims - ) - ) - - return optimizers diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index fb672cc4c7..a7f068c073 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -63,7 +63,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "none" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8] diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 6698852b47..38742cc716 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -1,4 +1,4 @@ -# DeepSeek-V3 in TorchTitan +# DeepSeek-V3 in `torchtitan` DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with Multi-head Latent Attention (MLA) architecture. 
@@ -50,11 +50,8 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml ## To be added -- Modeling - - Merge DeepSeek-V3 and Llama4 MoE common components - - Attention Layer: need to pass softmax_scale to sdpa() to support scaling - Parallelism - - Context Parallel support for DeepSeek-V3 + - Context Parallel support for DeepSeek V3 - torch.compile - Quantization - Testing diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 8243a0a84a..a39b35dfa2 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -8,10 +8,11 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers from torchtitan.models.llama3.infra.pipeline import pipeline_llama +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import register_train_spec, TrainSpec @@ -36,10 +37,14 @@ n_layers=3, n_dense_layers=1, n_heads=16, - n_routed_experts=8, - n_shared_experts=2, - n_activated_experts=3, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -55,10 +60,14 @@ n_layers=3, n_dense_layers=1, n_heads=16, - n_routed_experts=8, - n_shared_experts=2, - n_activated_experts=3, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -76,10 +85,14 @@ n_layers=27, n_dense_layers=1, n_heads=16, - n_routed_experts=64, - n_shared_experts=2, - n_activated_experts=6, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -95,12 +108,17 @@ n_layers=60, n_dense_layers=1, n_heads=128, - n_routed_experts=160, - n_shared_experts=2, - n_activated_experts=6, + moe_args=MoEArgs( + num_experts=160, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + route_scale=16.0, + score_before_experts=False, + ), n_expert_groups=8, n_limited_groups=3, - route_scale=16.0, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, @@ -115,13 +133,17 @@ n_layers=61, n_dense_layers=3, n_heads=128, - n_routed_experts=256, - n_shared_experts=1, - n_activated_experts=8, + moe_args=MoEArgs( + num_experts=256, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.5, + score_before_experts=False, + ), n_expert_groups=8, n_limited_groups=4, - route_scale=2.5, - score_func="sigmoid", q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, @@ -139,7 +161,7 @@ model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, 
build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 532358b2da..8e289f01fb 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -17,7 +17,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel +from torchtitan.distributed.expert_parallel import NoParallel from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index cd94104cdb..025a550b9b 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -7,12 +7,13 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal from torch import nn from torchtitan.config import JobConfig +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability @@ -67,16 +68,13 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_dense_layers: int = 1 n_heads: int = 16 norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE - n_routed_experts: int = 64 - n_shared_experts: int = 2 - n_activated_experts: int = 6 + moe_args: MoEArgs = field(default_factory=MoEArgs) + # TODO: node-limited routing is not supported yet n_expert_groups: int = 1 n_limited_groups: int = 1 - score_func: Literal["softmax", "sigmoid"] = "softmax" - route_scale: float = 1.0 - use_grouped_mm: bool = True - load_balance_coeff: float = 1e-3 + # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 kv_lora_rank: int = 512 @@ -85,6 +83,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): v_head_dim: int = 128 use_flex_attn: bool = False attn_mask_type: str = "causal" + # yarn original_seq_len: int = 4096 rope_theta: float = 10000.0 @@ -101,11 +100,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if self.use_grouped_mm and not has_cuda_capability(9, 0): + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", ) - self.use_grouped_mm = False + self.moe_args.use_grouped_mm = False if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise NotImplementedError( @@ -149,7 +148,7 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in nparams_sparse_active = ( nparams_moe_router + nparams_shared_expert - + nparams_experts * self.n_activated_experts // self.n_routed_experts + + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) logger.info( diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 1d92c12545..cfdc794ca9 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,12 +8,50 @@ from typing import Tuple import torch +import torch.nn.functional as F from torch import nn + from 
torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.moe import MoE from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs -from .moe import FeedForward, MoE + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -269,10 +307,14 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: - self.moe = MoE(model_args) + self.moe = MoE( + model_args.moe_args, + dim=model_args.dim, + hidden_dim=model_args.moe_inter_dim, + ) else: self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py deleted file mode 100644 index 02a094686c..0000000000 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F -from torch import nn -from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel - -from .args import DeepSeekV3ModelArgs - - -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. 
- - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class GroupedExperts(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - num_experts: int, - use_grouped_mm: bool, - ): - super().__init__() - self.num_experts = num_experts - self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) - self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) - self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) - self.use_grouped_mm = use_grouped_mm - - def forward( - self, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( - self.w1, self.w2, self.w3, x, num_tokens_per_expert - ) - else: - return GroupedExperts._run_experts_for_loop( - self.w1, self.w2, self.w3, x, num_tokens_per_expert - ) - - # TODO: keeping this for-loop implementation for comparison - # and readability, may remove later - @expert_parallel - @staticmethod - def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx])) - h = h * torch.matmul(x_expert, w3[expert_idx]) - h = torch.matmul(h, w2[expert_idx]) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1)) - h = h * torch.bmm(x, w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2) - - return out - - @expert_parallel - @staticmethod - def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 - - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) - out = torch._grouped_mm(h, 
w2.bfloat16(), offs=offsets).type_as(x) - - return out - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) - nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) - - -class TokenChoiceTopKRouter(nn.Module): - """This class implements token-choice routing. In token-choice top-K routing, each token is - routed to top K experts based on the router scores. - - Args: - gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). - num_experts (int): Number of experts in each moe layer. - top_k (int): Number of experts each token will be routed to in token-choice routing. - use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. - """ - - def __init__( - self, - dim: int, - num_experts: int, - top_k: int, - use_sigmoid: bool = False, - route_sclaing_factor: float = 1.0, - ): - super().__init__() - - self.dim = dim - self.num_experts = num_experts - self.top_k = top_k - self.use_sigmoid = use_sigmoid - self.route_sclaing_factor = route_sclaing_factor - self.gate = nn.Linear(self.dim, self.num_experts, bias=False) - - def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - TODO: We haven't implement the group-based routing (node limit routing), - and currently EP is not supporting node limit routing yet. - - Args: - x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. - - Returns: - routed_input (torch.Tensor): - Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. - token_indices (torch.Tensor): - Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_tokens_per_expert (torch.Tensor): - Number of tokens assigned to each expert with shape ``(num_experts,)``. - """ - # scores shape (bs*slen, num_experts) - scores = self.gate(x) - - # By default, sigmoid or softmax is performed in float32 to avoid loss explosion - if self.use_sigmoid: - scores = torch.sigmoid(scores.to(torch.float32)) - else: - scores = F.softmax(scores.to(torch.float32), dim=1) - - # top scores shape (bs*slen, top_k) - # NOTE: The expert_bias is only used for routing. The gating value - # top_scores is still derived from the original scores. 
- if expert_bias is not None: - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) - else: - top_scores, selected_experts_indices = torch.topk( - scores, k=self.top_k, dim=1 - ) - - if self.use_sigmoid: - denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 - top_scores = top_scores / denominator - - # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_tokens_per_expert = torch.histc( - selected_experts_indices.view(-1), - bins=self.num_experts, - min=0, - max=self.num_experts, - ) - - # Reorder the token indices to match the order of the experts - # token_indices_experts_sorted shape (bs*slen*top_k,) - token_indices_experts_sorted = torch.argsort( - selected_experts_indices.view(-1), stable=True - ) - - # reorder the scores to match the order of the token indices - top_scores = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - - top_scores = ( - top_scores * self.route_sclaing_factor - ) # must multiply the scaling factor - return top_scores, token_indices_experts_sorted, num_tokens_per_expert - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) - - -class MoE(nn.Module): - def __init__(self, model_args: DeepSeekV3ModelArgs): - - super().__init__() - dim = model_args.dim - - num_experts = model_args.n_routed_experts - hidden_dim = model_args.moe_inter_dim - top_k = model_args.n_activated_experts - route_scaling_factor = model_args.route_scale - - self.experts = GroupedExperts( - dim=dim, - hidden_dim=hidden_dim, - num_experts=num_experts, - use_grouped_mm=model_args.use_grouped_mm, - ) - self.router = TokenChoiceTopKRouter( - dim=dim, - num_experts=num_experts, - top_k=top_k, - use_sigmoid=model_args.score_func == "sigmoid", - route_sclaing_factor=route_scaling_factor, - ) - self.shared_expert = ( - # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 - GroupedExperts( - dim=dim, - hidden_dim=hidden_dim * model_args.n_shared_experts, - num_experts=1, # Here needs to be 1 to make it equivalent to the MLP - use_grouped_mm=model_args.use_grouped_mm, - ) - if model_args.n_shared_experts > 0 - else None - ) - - # auxiliary-loss-free load balancing - self.load_balance_coeff = model_args.load_balance_coeff - if self.load_balance_coeff is not None: - assert self.load_balance_coeff > 0.0 - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - ) - else: - self.expert_bias = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. - - Returns: - out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. - """ - bs, slen, dim = x.shape - - # top_scores and selected_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - token_indices, - num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - - # tokens_per_expert will be used to update the expert bias for load balancing. - # Prevent extra local tokens accumulation on evaluation or activation recomputation. 
- if self.load_balance_coeff is not None and torch.is_grad_enabled(): - with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) - # shape (bs*slen*top_k, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather( - x.view(-1, dim), - dim=0, - index=token_indices, - ) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( - x.dtype - ) - - # shared expert - if self.shared_expert is not None: - out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( - bs * slen, dim - ) - else: - out = torch.zeros_like(x.reshape(bs * slen, dim)) - - # Accumulate multiple expert results becase each token can be routed to multiple experts - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) - out = out.reshape(bs, slen, dim) - return out - - def init_weights( - self, - init_std: float, - buffer_device: torch.device, - ): - self.experts.init_weights(init_std) - self.router.init_weights(init_std) - if self.shared_expert is not None: - self.shared_expert.init_weights(init_std) - - if self.load_balance_coeff is not None: - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 64d080126b..093f89a18b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -65,7 +65,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "none" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8] diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index db3d6465e6..8741b2eef4 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -19,7 +19,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.distributed.pipeline import ( +from torchtitan.distributed.pipeline_parallel import ( build_pipeline_schedule, generate_llm_fqn_per_model_part, pipeline_module_split, diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/models/moe.py similarity index 80% rename from torchtitan/experiments/llama4/model/moe.py rename to torchtitan/models/moe.py index 73a5d0a205..b8d777306c 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/models/moe.py @@ -4,13 +4,31 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from typing import Literal + import torch import torch.nn.functional as F from torch import nn -from ..infra.expert_parallel import expert_parallel +from torchtitan.distributed.expert_parallel import expert_parallel + + +@dataclass +class MoEArgs: + num_experts: int = 8 + num_shared_experts: int = 1 -from .args import TransformerModelArgs + # router + score_func: Literal["softmax", "sigmoid"] = "sigmoid" + route_norm: bool = False + route_scale: float = 1.0 + score_before_experts: bool = True + + # token-choice + top_k: int = 1 + use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation + load_balance_coeff: float | None = 1e-3 class GroupedExperts(nn.Module): @@ -142,13 +160,17 @@ def __init__( dim: int, num_experts: int, top_k: int, - use_sigmoid: bool = False, + score_func: Literal["softmax", "sigmoid"], + route_norm: bool, + route_scale: float, ): super().__init__() self.gate = nn.Linear(dim, num_experts, bias=False) self.num_experts = num_experts self.top_k = top_k - self.use_sigmoid = use_sigmoid + self.score_func = score_func + self.route_norm = route_norm + self.route_scale = route_scale def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None @@ -169,10 +191,12 @@ def forward( scores = self.gate(x) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion - if self.use_sigmoid: + if self.score_func == "sigmoid": scores = torch.sigmoid(scores.to(torch.float32)) - else: + elif self.score_func == "softmax": scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_func}") # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value @@ -187,6 +211,11 @@ def forward( scores, k=self.top_k, dim=1 ) + if self.score_func == "sigmoid" and self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), @@ -194,10 +223,13 @@ def forward( min=0, max=self.num_experts, ) + + # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) token_indices_experts_sorted = torch.argsort( selected_experts_indices.view(-1), stable=True ) + top_scores = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k @@ -208,50 +240,43 @@ def init_weights(self, init_std: float): class MoE(nn.Module): - def __init__(self, model_args: TransformerModelArgs): + def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() - dim = model_args.dim - hidden_dim = 4 * model_args.dim - ffn_dim_multiplier = model_args.ffn_dim_multiplier - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - - num_experts = model_args.num_experts - - hidden_dim_denom = 1 - if model_args.auto_scale_hidden_dim: - hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert) - - if model_args.auto_scale_hidden_dim: - hidden_dim = int(hidden_dim / hidden_dim_denom) - hidden_dim += -hidden_dim % model_args.multiple_of + num_experts = moe_args.num_experts self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - 
use_grouped_mm=model_args.use_grouped_mm, + use_grouped_mm=moe_args.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( - dim=dim, num_experts=num_experts, top_k=model_args.top_k + dim=dim, + num_experts=num_experts, + top_k=moe_args.top_k, + score_func=moe_args.score_func, + route_norm=moe_args.route_norm, + route_scale=moe_args.route_scale, ) self.shared_expert = ( GroupedExperts( dim=dim, - hidden_dim=hidden_dim, + # TODO: if it doesn't use GroupedExperts.num_experts + # we can just use normal FeedForward + hidden_dim=hidden_dim * moe_args.num_shared_experts, num_experts=1, - use_grouped_mm=model_args.use_grouped_mm, + use_grouped_mm=moe_args.use_grouped_mm, ) - if model_args.use_shared_expert + if moe_args.num_shared_experts > 0 else None ) + self.score_before_experts = moe_args.score_before_experts # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert is accumulated in the model forward pass. # expert_bias is updated outside the model in an optimzer step pre hook # to work with gradient accumulation. - self.load_balance_coeff = model_args.load_balance_coeff + self.load_balance_coeff = moe_args.load_balance_coeff if self.load_balance_coeff is not None: assert self.load_balance_coeff > 0.0 self.register_buffer( @@ -284,8 +309,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. - # Prevent extra local tokens accumulation on evaluation or activation recomputation. - if self.load_balance_coeff is not None and torch.is_grad_enabled(): + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + if self.load_balance_coeff is not None: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) @@ -298,13 +325,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices, ) - routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to( - x.dtype - ) + + if self.score_before_experts: + routed_input = ( + routed_input.to(torch.float32) * top_scores.reshape(-1, 1) + ).to(x.dtype) # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + if not self.score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores.reshape(-1, 1) + ).to(x.dtype) + # shared expert if self.shared_expert is not None: out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( diff --git a/torchtitan/train.py b/torchtitan/train.py index 0955bbb2cb..10abf8fee4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -478,11 +478,7 @@ def train_step( pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), - ep_dense_params_mesh_ndim=( - parallel_dims.dense_params_mesh_ndim - if parallel_dims.ep_enabled - else None - ), + ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step()
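
For reference, the auxiliary-loss-free load balancing that `build_optimizers_with_moe_load_balancing` wires up as an optimizer step pre-hook reduces to the update below. This is a minimal sketch: the standalone helper name is illustrative, and in the patch the same logic runs per MoE layer inside `_update_expert_bias`, after an optional all-reduce of `tokens_per_expert` over the `dp_cp` mesh.

```python
import torch

@torch.no_grad()
def update_expert_bias(
    tokens_per_expert: torch.Tensor,  # (num_experts,) counts accumulated in forward
    expert_bias: torch.Tensor,        # (num_experts,) bias used only for routing
    load_balance_coeff: float,
) -> None:
    # Under-loaded experts (fewer tokens than the mean) get a positive nudge,
    # over-loaded ones a negative nudge; subtracting the mean keeps the bias
    # zero-centered. Counts are then reset for the next accumulation window.
    delta = load_balance_coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
    delta = delta - delta.mean()
    expert_bias.add_(delta)
    tokens_per_expert.zero_()
```

Because only the sign of the imbalance enters the update, the double counting of `tokens_per_expert` under activation checkpointing that the patch comments on does not change the resulting bias.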
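
The hand-rolled `_A2A` autograd function stands in for `torch.distributed._functional_collectives.all_to_all_single_autograd`, which currently leaks memory when combined with activation checkpointing (see issue #1467 referenced in the patch). A rough usage sketch follows, assuming an initialized process group; the split exchange shown is the usual MoE dispatch pattern, not code from the patch.

```python
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd

def dispatch_tokens(tokens: torch.Tensor, send_counts: list[int], group) -> torch.Tensor:
    # send_counts[i] is the number of rows of `tokens` destined for rank i.
    in_splits = torch.tensor(send_counts, device=tokens.device)
    out_splits = torch.empty_like(in_splits)
    # Exchange counts first so every rank knows how many rows it will receive.
    dist.all_to_all_single(out_splits, in_splits, group=group)
    # Forward runs all_to_all_single with these splits; backward reruns the
    # collective with the splits swapped, so gradients return to their source ranks.
    return all_to_all_single_autograd(
        tokens, out_splits.tolist(), in_splits.tolist(), group
    )
```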
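
With the router knobs consolidated into `MoEArgs`, both model families configure the shared `torchtitan.models.moe.MoE` the same way; the values below mirror the Llama 4 `17bx16e` and DeepSeek-V3 debug entries in this patch (other model fields omitted).

```python
from torchtitan.models.moe import MoEArgs

# Llama 4 17B x 16E: only the expert count is overridden; the remaining fields
# fall back to the MoEArgs defaults defined in torchtitan/models/moe.py.
llama4_moe = MoEArgs(num_experts=16)

# DeepSeek-V3 debug config: softmax scoring with normalization, top-3 routing,
# two shared experts, and gating applied after the expert computation.
deepseek_moe = MoEArgs(
    num_experts=8,
    num_shared_experts=2,
    top_k=3,
    score_func="softmax",
    route_norm=True,
    score_before_experts=False,
)

# Either instance is passed to the shared module as
#   MoE(moe_args, dim=..., hidden_dim=...), with dim/hidden_dim supplied by the model.
```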
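
The Llama 4 MoE hidden-dimension computation now lives in the `TransformerBlock` rather than inside `MoE.__init__`; a worked example of what it produces, using hypothetical model dimensions (not values from the patch) and the `MoEArgs` defaults:

```python
# Hypothetical inputs: dim and ffn_dim_multiplier are illustrative only.
dim, ffn_dim_multiplier, multiple_of = 5120, 1.2, 2048
top_k, num_shared_experts = 1, 1  # MoEArgs defaults

hidden_dim = int(2 * (4 * dim) / 3)                 # 13653
hidden_dim = int(ffn_dim_multiplier * hidden_dim)   # 16383
# auto_scale_hidden_dim: split the FFN budget across routed + shared experts
hidden_dim = int(hidden_dim / (top_k + num_shared_experts))  # 8191
hidden_dim += -hidden_dim % multiple_of             # round up to a multiple of 2048 -> 8192
```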