@@ -14,8 +14,6 @@
 import argparse
 import copy
 import os
-import statistics
-from time import perf_counter_ns
 
 import pytest
 import torch
@@ -24,6 +22,11 @@
 from torch.distributed._composable.fsdp import fully_shard
 from torch.nn import functional as F
 
+from benchmarks.prototype.moe_training.utils import (
+    bench_fwd_bwd_microseconds,
+    profile_fn,
+)
+
 # this feature requires CUDA and SM89+
 if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
     pytest.skip(
@@ -48,8 +51,12 @@
     )
 
 
-def bench_moe_float8_training_fsdp(enable_profile=False):
+def bench_moe_float8_training_fsdp(
+    recipe_name: str, enable_profile: bool, use_compile: bool
+):
     assert torch.cuda.is_available()
+    assert recipe_name in ["fp8_rowwise", "mxfp8"]
+    recipe = MoEScalingType[recipe_name.upper()]
 
     # setup distributed for fsdp
     setup_distributed()
@@ -62,15 +69,19 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
     init_std = 0.02
     device = torch.device("cuda")
 
-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 4 * 5120
+    # reference bf16 MoE using llama4 shapes
+    dim, hidden_dim = 5120, 8192
     ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
     # target MoE for testing conversion
     model = copy.deepcopy(ref_model)
 
+    # Token group alignment size must be 16 for fp8 rowwise and 32 for mxfp8
+    alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+    set_token_group_alignment_size_m(alignment_size)
+
     # assert starting params are identical for both models
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         assert torch.equal(param1, param2)
@@ -83,15 +94,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig(scaling_type=recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
     fully_shard(model)
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq = 1, 8192
+    batch, seq = 1, 16640
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -104,70 +115,34 @@ def warmup(model, input):
         loss.backward()
         torch.cuda.synchronize()
 
-    def bench_fn_microseconds(model, input):
-        labels = torch.ones_like(input)
-        times = []
-        for _ in range(10):
-            start_ns = perf_counter_ns()
-            out = model(input)
-            loss = F.mse_loss(out, labels)
-            loss.backward()
-            torch.cuda.synchronize()
-            end_ns = perf_counter_ns()
-            duration_us = (end_ns - start_ns) / 1000
-            times.append(duration_us)
-        return statistics.median(times)
-
-    def profile_fn(model, input, profile_name="profile"):
-        # Only profile on rank 0
-        if torch.distributed.get_rank() == 0:
-            labels = torch.ones_like(input)
-            wait, warmup, active = 1, 3, 1
-            total_steps = wait + warmup + active
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                schedule=torch.profiler.schedule(
-                    wait=wait, warmup=warmup, active=active, repeat=0
-                ),
-                record_shapes=True,
-                with_stack=True,
-            ) as prof:
-                for _ in range(total_steps):
-                    out = model(input)
-                    loss = F.mse_loss(out, labels)
-                    loss.backward()
-                    prof.step()
-
-            # Save profiler results
-            prof.export_chrome_trace(f"{profile_name}.json")
-            print(f"Saved: {profile_name}.json")
-
-    # Compile models
-    ref_model = torch.compile(ref_model, fullgraph=False)
-    model = torch.compile(model, fullgraph=False)
-
-    print("Benchmarking MoE with FSDP2 using bf16 training")
-    warmup(ref_model, ref_x)
-    bf16_us = bench_fn_microseconds(ref_model, ref_x)
-    print(f"bf16 time: {bf16_us} us")
-    if enable_profile:
-        print("Profiling bf16 model")
-        profile_fn(ref_model, ref_x, profile_name="bf16_profile")
+    labels = torch.ones_like(x)
 
-    # Token group alignment size must be 16 for fp8 rowwise training
-    set_token_group_alignment_size_m(16)
-
-    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
-    warmup(model, x)
-    fp8_us = bench_fn_microseconds(model, x)
-    print(f"fp8 time: {fp8_us} us")
+    # TODO: bench with fullgraph=True if/when it is supported
+    bf16_us = bench_fwd_bwd_microseconds(
+        ref_model,
+        ref_x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"BF16 time: {bf16_us} us")
+    if enable_profile:
+        print("Profiling bf16 training")
+        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+
+    scaled_us = bench_fwd_bwd_microseconds(
+        model,
+        x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"Scaled time: {scaled_us} us")
     if enable_profile:
-        print("Profiling fp8 model")
-        profile_fn(model, x, profile_name="fp8_profile")
+        print("Profiling quantized training")
+        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
 
+    print(f"Speedup: {bf16_us / scaled_us:.3f}x")
     dist.destroy_process_group()
 
 
@@ -185,5 +160,15 @@ def setup_distributed():
         action="store_true",
         help="Enable PyTorch profiling and save results to file",
     )
+    parser.add_argument("--recipe", type=str, default="fp8_rowwise", help="fp8_rowwise or mxfp8")
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="Use torch.compile",
+    )
     args = parser.parse_args()
-    bench_moe_float8_training_fsdp(enable_profile=args.profile)
+    bench_moe_float8_training_fsdp(
+        recipe_name=args.recipe,
+        enable_profile=args.profile,
+        use_compile=args.compile,
+    )
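
Note: the inline timing and profiling helpers deleted by this diff are replaced with imports from benchmarks.prototype.moe_training.utils, whose body is not part of the diff. Below is a minimal sketch of what bench_fwd_bwd_microseconds could look like, assuming it folds the deleted median-of-runs timer together with the torch.compile step and a warmup pass; the num_iters and warmup_iters parameters are illustrative, not taken from the PR. profile_fn presumably mirrors the deleted rank-0 torch.profiler wrapper, with the extra labels argument threaded through the loss computation.

import statistics
from time import perf_counter_ns

import torch
from torch.nn import functional as F


def bench_fwd_bwd_microseconds(
    model, input, labels=None, use_compile=False, fullgraph=False,
    num_iters=10, warmup_iters=3,
):
    # Hypothetical reconstruction of the shared util; the real
    # benchmarks/prototype/moe_training/utils.py is not shown in this diff.
    if use_compile:
        model = torch.compile(model, fullgraph=fullgraph)
    if labels is None:
        # assumes the model output has the same shape as its input
        labels = torch.ones_like(input)

    # warmup so compile/autotune time is excluded from the measurement
    for _ in range(warmup_iters):
        F.mse_loss(model(input), labels).backward()
    torch.cuda.synchronize()

    times = []
    for _ in range(num_iters):
        start_ns = perf_counter_ns()
        out = model(input)
        loss = F.mse_loss(out, labels)
        loss.backward()
        torch.cuda.synchronize()
        end_ns = perf_counter_ns()
        times.append((end_ns - start_ns) / 1e3)  # ns -> us
    return statistics.median(times)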
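Since the script calls setup_distributed() and wraps both models with fully_shard, it has to be launched through a distributed launcher rather than plain python. A typical single-node invocation (the script path here is assumed for illustration) looks like:

torchrun --nproc_per_node=8 benchmarks/prototype/moe_training/benchmark_moe_fsdp.py --recipe mxfp8 --compile --profile

With --profile set, rank 0 presumably also exports chrome traces (bf16_profile.json and mxfp8_profile.json, per the profile_name arguments above) that can be opened in chrome://tracing or Perfetto.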