@dataclass(frozen=True)
class ExperimentConfig:
    """One benchmark case: problem shape, scaling recipe, and baseline dtype.

    Frozen so configs are immutable and safe to share across runs.
    """

    # dtype used for the high-precision baseline grouped mm (e.g. torch.bfloat16)
    high_precision_dtype: torch.dtype
    # (total_M, N, K, G): total tokens, output dim, input dim, expert-group count.
    # NOTE: was annotated `tuple[int]` (a 1-tuple); it is unpacked as 4 values.
    MNKG: tuple[int, int, int, int]
    # low-precision scaling recipe to benchmark (e.g. FP8 rowwise, MXFP8)
    recipe: MoEScalingType
4241
@dataclass(frozen=True)
class ExperimentResult:
    """Timings (microseconds) and speedups measured for one ExperimentConfig."""

    # combined forward + backward pass timings
    bf16_fwd_bwd_us: float
    scaled_fwd_bwd_us: float
    # ratio bf16_fwd_bwd_us / scaled_fwd_bwd_us (>1 means scaled is faster)
    scaled_fwd_bwd_speedup: float
    # forward-only timings
    bf16_fwd_us: float
    scaled_fwd_us: float
    # ratio bf16_fwd_us / scaled_fwd_us
    scaled_fwd_speedup: float
@@ -57,22 +56,46 @@ class Experiment:
5756
5857
def get_configs() -> List[ExperimentConfig]:
    """Build the benchmark grid.

    Returns one ExperimentConfig per combination of (M, N, K, G) problem
    shape, scaling recipe, and high-precision baseline dtype.
    """
    MNKG_list = [
        # Llama4 16e with various experts per device (i.e., different EP degrees)
        (16384, 8192, 5120, 1),
        (16384, 8192, 5120, 2),
        (16384, 8192, 5120, 4),
        (16384, 8192, 5120, 8),
        (128000, 8192, 5120, 1),
        (128000, 8192, 5120, 2),
        (128000, 8192, 5120, 4),
        (128000, 8192, 5120, 8),
        # DSV3 236B with various experts per device (i.e., different EP degrees)
        (16384, 1536, 5120, 1),
        (16384, 1536, 5120, 2),
        (16384, 1536, 5120, 4),
        (16384, 1536, 5120, 8),
        (128000, 1536, 5120, 1),
        (128000, 1536, 5120, 2),
        (128000, 1536, 5120, 4),
        (128000, 1536, 5120, 8),
        # DSV3 671B with various experts per device (i.e., different EP degrees)
        (16384, 2048, 7168, 1),
        (16384, 2048, 7168, 2),
        (16384, 2048, 7168, 4),
        (16384, 2048, 7168, 8),
        (128000, 2048, 7168, 1),
        (128000, 2048, 7168, 2),
        (128000, 2048, 7168, 4),
        (128000, 2048, 7168, 8),
    ]
    recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8]
    high_precision_dtypes = [torch.bfloat16]
    # Full cross product of shapes x recipes x dtypes.
    return [
        ExperimentConfig(
            MNKG=mnkg,
            recipe=recipe,
            high_precision_dtype=high_precision_dtype,
        )
        for mnkg, recipe, high_precision_dtype in itertools.product(
            MNKG_list,
            recipes,
            high_precision_dtypes,
        )
    ]
@@ -83,15 +106,17 @@ def get_configs() -> List[ExperimentConfig]:
83106def run_experiment (
84107 config : ExperimentConfig , args : argparse .Namespace
85108) -> ExperimentResult :
109+ total_M , N , K , G = config .MNKG
110+
86111 # define test inputs
87112 A = torch .randn (
88- * config . A_shape ,
113+ ( total_M , K ) ,
89114 dtype = config .high_precision_dtype ,
90115 device = device ,
91116 requires_grad = True ,
92117 )
93118 B_t = torch .randn (
94- * config . B_shape ,
119+ ( G , N , K ) ,
95120 dtype = config .high_precision_dtype ,
96121 device = device ,
97122 requires_grad = True ,
@@ -102,17 +127,15 @@ def run_experiment(
102127 # that occurs in the backward pass of the differentiable scaled grouped mm.
103128 # - the transposed tensor in col-major format with groups along the row dimension,
104129 # which represents the right operand.
105- n_groups = config .B_shape [0 ]
106- Mg = A .shape [0 ]
107130 token_group_alignment_size = 32 if config .recipe == MoEScalingType .MXFP8 else 16
108- offs = generate_jagged_offs (n_groups , Mg , multiple_of = token_group_alignment_size )
131+ offs = generate_jagged_offs (G , total_M , multiple_of = token_group_alignment_size )
109132
110133 labels = torch .ones (
111134 (A .shape [0 ], B_t .shape [- 1 ]), device = device , dtype = torch .bfloat16
112135 )
113136
114- # E2E bf16 benchmark + profiling
115- bf16_e2e_us = bench_fwd_bwd_microseconds (
137+ # fwd_bwd bf16 benchmark + profiling
138+ bf16_fwd_bwd_us = bench_fwd_bwd_microseconds (
116139 torch ._grouped_mm ,
117140 A ,
118141 B_t ,
@@ -133,8 +156,8 @@ def run_experiment(
133156 profile_name = "bf16_profile" ,
134157 )
135158
136- # E2E scaled benchmark + profiling
137- scaled_e2e_us = bench_fwd_bwd_microseconds (
159+ # fwd_bwd scaled benchmark + profiling
160+ scaled_fwd_bwd_us = bench_fwd_bwd_microseconds (
138161 _scaled_grouped_mm ,
139162 A ,
140163 B_t ,
@@ -177,9 +200,9 @@ def run_experiment(
177200 )
178201
179202 return ExperimentResult (
180- bf16_e2e_us = round (bf16_e2e_us , 3 ),
181- scaled_e2e_us = round (scaled_e2e_us , 3 ),
182- scaled_e2e_speedup = round (bf16_e2e_us / scaled_e2e_us , 3 ),
203+ bf16_fwd_bwd_us = round (bf16_fwd_bwd_us , 3 ),
204+ scaled_fwd_bwd_us = round (scaled_fwd_bwd_us , 3 ),
205+ scaled_fwd_bwd_speedup = round (bf16_fwd_bwd_us / scaled_fwd_bwd_us , 3 ),
183206 bf16_fwd_us = round (bf16_fwd_us , 3 ),
184207 scaled_fwd_us = round (scaled_fwd_us , 3 ),
185208 scaled_fwd_speedup = round (bf16_fwd_us / scaled_fwd_us , 3 ),
@@ -188,28 +211,24 @@ def run_experiment(
188211
189212def print_results (experiments : List [Experiment ]):
190213 headers = [
191- "A_shape" ,
192- "B_shape" ,
214+ "M,N,K,G" ,
193215 "recipe" ,
194- "bf16_e2e_us " ,
195- "scaled_e2e_us " ,
196- "scaled_e2e_speedup " ,
216+ "bf16_fwd_bwd_us " ,
217+ "scaled_fwd_bwd_us " ,
218+ "scaled_fwd_bwd_speedup " ,
197219 "bf16_fwd_us" ,
198220 "scaled_fwd_us" ,
199221 "scaled_fwd_speedup" ,
200222 ]
201223 rows = []
202224 for experiment in experiments :
203- A_shape = f"({ experiment .config .A_shape [0 ]} , { experiment .config .A_shape [1 ]} )"
204- B_shape = f"({ experiment .config .B_shape [0 ]} , { experiment .config .B_shape [1 ]} , { experiment .config .B_shape [2 ]} )"
205225 rows .append (
206226 [
207- A_shape ,
208- B_shape ,
227+ str (experiment .config .MNKG ),
209228 experiment .config .recipe ,
210- experiment .result .bf16_e2e_us ,
211- experiment .result .scaled_e2e_us ,
212- f"{ experiment .result .scaled_e2e_speedup } x" ,
229+ experiment .result .bf16_fwd_bwd_us ,
230+ experiment .result .scaled_fwd_bwd_us ,
231+ f"{ experiment .result .scaled_fwd_bwd_speedup } x" ,
213232 experiment .result .bf16_fwd_us ,
214233 experiment .result .scaled_fwd_us ,
215234 f"{ experiment .result .scaled_fwd_speedup } x" ,
0 commit comments