     ],
 )
 @pytest.mark.parametrize("compile", [False, True])
-def test_moe_float8_training(target_fqns: list[str], compile: bool):
-    # Set token group alignment size to 16. This is required so that
-    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
-    # has the contraction dim be divisible by 16. 16 byte alignment is required
-    # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
-    set_token_group_alignment_size_m(16)
-    model_args = MoEArgs(
-        num_experts=8,
-    )
-    init_std = 0.02
-    device = torch.device("cuda")
-
-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 8192
-    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
-    torch.manual_seed(42)
-    ref_model.init_weights(init_std, device)
-
-    # target MoE for testing conversion
-    model = copy.deepcopy(ref_model)
-
-    # assert starting params are identical for both models
-    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
-        assert torch.equal(param1, param2)
-
-    # convert MoE to float8 training
-    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
-        for target_fqn in target_fqns:
-            if target_fqn in cur_fqn:
-                return True
-        return False
-
-    # quantize test model
-    config = MoETrainingConfig()
-    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
-
-    # validate that only the experts were converted
-    _validate_model_conversion(
-        model,
-        target_fqns=target_fqns,
-    )
-    if compile:
-        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
-        model = torch.compile(model, fullgraph=False)
-        ref_model = torch.compile(ref_model, fullgraph=False)
-
-    # inputs
-    batch, seq = 8, 2048
-    ref_x = torch.randn(
-        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
-    )
-    x = ref_x.detach().clone().requires_grad_(True)
-
-    # forward pass
-    ref_out = ref_model(ref_x)
-    out = model(x)
-
-    # validate output
-    out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
-    assert out_sqnr.item() >= min_out_sqnr, (
-        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
-    )
-
-    # compute loss
-    labels = torch.ones_like(ref_out)
-    ref_loss = F.mse_loss(ref_out, labels)
-    out_loss = F.mse_loss(out, labels)
-
-    # backward pass
-    ref_loss.backward()
-    out_loss.backward()
-
-    # validate input gradient
-    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    min_input_grad_sqnr = 29.0
-    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
-        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
-    )
-
-    # validate param gradients
-    min_param_grad_sqnr = 23.0
-    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
-        param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
-            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
-        )
-
-
 @pytest.mark.parametrize(
-    "target_fqns",
+    "recipe_config",
     [
-        ["experts"],
-        ["does.not.exist"],
+        # {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0},
+        {
+            "recipe": MoEScalingType.MXFP8,
+            "group_alignment_size": 32,
+            "min_out_sqnr": 28.0,
+            "min_input_grad_sqnr": 29.0,
+            "min_param_grad_sqnr": 21.0,
+        },
     ],
 )
-@pytest.mark.parametrize("compile", [False, True])
-def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
-    block_size = 32
-
-    # Token groups must be divisible by 32 for mxfp8
-    set_token_group_alignment_size_m(block_size)
-
+def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict):
+    (
+        recipe,
+        group_alignment_size,
+        min_out_sqnr,
+        min_input_grad_sqnr,
+        min_param_grad_sqnr,
+    ) = (
+        recipe_config["recipe"],
+        recipe_config["group_alignment_size"],
+        recipe_config["min_out_sqnr"],
+        recipe_config["min_input_grad_sqnr"],
+        recipe_config["min_param_grad_sqnr"],
+    )
+    # Set token group alignment size. This is required so that each logically distinct
+    # gemm in the grouped gemm `grad_weight = grad_output_t @ input` has its contraction
+    # dim divisible by the alignment size: 16 elements for fp8 rowwise (16-byte alignment
+    # for the slowest moving dim (stride 1) / 1 byte per element), 32 elements for mxfp8.
+    set_token_group_alignment_size_m(group_alignment_size)
     model_args = MoEArgs(
         num_experts=8,
     )
@@ -170,15 +99,14 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    config = MoETrainingConfig(scaling_type=recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
@@ -197,7 +125,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 28.0
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
@@ -213,13 +140,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    min_input_grad_sqnr = 30.0
     assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
         f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
-    min_param_grad_sqnr = 21.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
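
A minimal sketch (not part of this diff) of the conversion flow the merged test exercises. The torchao prototype import paths below, and the assumption that `model` is a torchtitan-style bf16 MoE module on CUDA, are inferred from the symbols the test uses rather than verified APIs:

    from torch import nn

    # NOTE: module paths for the prototype MoE-training symbols are assumptions;
    # they mirror the names used in the test, not verified import locations.
    from torchao.quantization import quantize_
    from torchao.prototype.moe_training.conversion_utils import (
        MoEScalingType,
        MoETrainingConfig,
    )
    from torchao.prototype.moe_training.utils import set_token_group_alignment_size_m


    def convert_experts_to_mxfp8(model: nn.Module, target_fqns=("experts",)) -> nn.Module:
        """Convert expert weights of a torchtitan-style MoE module for mxfp8 training."""
        # mxfp8 scales in blocks of 32 elements, so token groups fed to the grouped
        # gemms must be aligned to multiples of 32 (16 suffices for fp8 rowwise).
        set_token_group_alignment_size_m(32)

        def filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            # Convert only submodules whose fully-qualified name contains a target fqn.
            return any(t in cur_fqn for t in target_fqns)

        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
        quantize_(model, config=config, filter_fn=filter_fn)  # swaps expert params in place
        return model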