 # LICENSE file in the root directory of this source tree.


-from dataclasses import dataclass
+from dataclasses import dataclass, field

 from torch import nn

 from torchtitan.config import JobConfig
+
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols import BaseModelArgs
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
@@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs):

     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
-    eos_id: int = 0
     # iRoPE settings
     # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
     # used every n layers. Other layers uses RoPE (rotary positional embedding) and
@@ -45,17 +46,11 @@ class TransformerModelArgs(BaseModelArgs):
     every_n_layers_nope: int | None = None
     fixed_attn_block_size: int = 8192

-    # MoE args
-    moe_enabled: bool = True
-    num_experts: int = 8
-    use_shared_expert: bool = True
+    # MoE
+    moe_args: MoEArgs = field(default_factory=MoEArgs)
     auto_scale_hidden_dim: bool = True
     # frequency of using MoE layer instead of feedforward layer in a transformer block
     interleave_moe_layer_step: int = 2
-    # token-choice
-    top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
-    load_balance_coeff: float | None = 1e-3

     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
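An aside on the `moe_args: MoEArgs = field(default_factory=MoEArgs)` line in the hunk above: a dataclass default written as `MoEArgs()` would be a single object shared by every `TransformerModelArgs` instance (and newer Python versions reject such unhashable defaults outright), so the nested config has to be built per instance via `default_factory`. A minimal standalone sketch of the pattern, using stand-in class names rather than the real torchtitan definitions:

from dataclasses import dataclass, field


@dataclass
class NestedArgs:  # stand-in for MoEArgs
    num_experts: int = 8
    top_k: int = 1


@dataclass
class OuterArgs:  # stand-in for TransformerModelArgs
    # `nested: NestedArgs = NestedArgs()` would share one object across all
    # OuterArgs instances (and is rejected as a mutable default on newer
    # Python), so default_factory builds a fresh NestedArgs per instance.
    nested: NestedArgs = field(default_factory=NestedArgs)


a, b = OuterArgs(), OuterArgs()
a.nested.top_k = 2
assert b.nested.top_k == 1  # each instance owns its own NestedArgs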
@@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

-        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+        if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
             )
-            self.use_grouped_mm = False
+            self.moe_args.use_grouped_mm = False

         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise NotImplementedError(
@@ -112,7 +107,7 @@ def get_nparams_and_flops(
         nparams_sparse_active = (
             nparams_moe_router
             + nparams_shared_expert
-            + nparams_experts * self.top_k // self.num_experts
+            + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )

         logger.info(
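Net effect of the diff: the flat MoE knobs on `TransformerModelArgs` (`num_experts`, `top_k`, `use_grouped_mm`, `moe_enabled`, `use_shared_expert`, `load_balance_coeff`) are folded into the nested `MoEArgs` dataclass reached through `moe_args`, and `update_from_config` / `get_nparams_and_flops` now read and write that nested field. A rough sketch of how MoE might be configured after this change; only `num_experts`, `top_k`, and `use_grouped_mm` are confirmed as `MoEArgs` fields by this diff, and the `TransformerModelArgs` import path is assumed:

# Sketch only: the TransformerModelArgs import path and any MoEArgs fields
# beyond num_experts / top_k / use_grouped_mm are assumptions, not from this PR.
from torchtitan.models.moe import MoEArgs
from torchtitan.models.llama4 import TransformerModelArgs  # assumed import path

model_args = TransformerModelArgs(
    moe_args=MoEArgs(
        num_experts=8,        # previously the flat `num_experts` default
        top_k=1,              # previously the flat `top_k` (token-choice routing)
        use_grouped_mm=True,  # previously the flat `use_grouped_mm`
    ),
    interleave_moe_layer_step=2,  # still a flat field on TransformerModelArgs
)

# update_from_config() now toggles the nested flag when grouped mm (SM90+)
# is unavailable, e.g. model_args.moe_args.use_grouped_mm = False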