@@ -22,11 +22,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool):
     # has the contraction dim be divisible by 16. 16 byte alignment is required
     # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
     set_token_group_alignment_size_m(16)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
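For a standalone view of the construction path this hunk switches to, a minimal sketch, assuming only the torchtitan MoEArgs/MoE signatures visible in the diff (the widths and init values are copied from the test above):

    # sketch: build the bf16 reference MoE with the new API; the expert width is now
    # passed to the constructor rather than carried inside the model args
    import torch
    from torchtitan.models.moe import MoE, MoEArgs

    model_args = MoEArgs(num_experts=8)
    dim, hidden_dim = 5120, 4 * 5120
    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(0.02, torch.device("cuda"))
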
@@ -75,22 +73,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig()
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
         ref_model = torch.compile(ref_model, fullgraph=False)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
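For reference, a minimal sketch of the conversion step this hunk simplifies. The import path for MoETrainingConfig and the "experts" target FQN are assumptions, and moe_module_filter_fn is a hypothetical reconstruction of the filter named in the hunk header; dropping the explicit scaling_type relies on fp8 rowwise being the config's default:

    from torch import nn
    from torchao.quantization import quantize_
    # assumed import path for the prototype MoE training config
    from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig

    target_fqns = ["experts"]  # hypothetical value for the parametrized target FQNs

    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        # convert a module only when its fully-qualified name matches a target FQN
        return any(target in cur_fqn for target in target_fqns)

    config = MoETrainingConfig()  # scaling type left at its (fp8 rowwise) default
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)  # model: MoE copy under test
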
@@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]):
     # Token groups must be divisible by 32 for mxfp8
     set_token_group_alignment_size_m(block_size)
 
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        multiple_of=block_size,
-        ffn_dim_multiplier=1.0,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
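The remaining difference between the two tests is the token-group alignment. A minimal sketch of the two settings, assuming block_size (defined earlier in the test, outside this diff) is the 32-element mxfp8 scaling block:

    from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m

    # fp8 rowwise: 16-byte alignment on the stride-1 dim -> 16 one-byte fp8 elements per group
    set_token_group_alignment_size_m(16)

    # mxfp8: token groups must be divisible by the scaling block size
    block_size = 32  # assumed value of the block_size used by the test
    set_token_group_alignment_size_m(block_size)
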
@@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )