@@ -35,8 +35,8 @@ class ExperimentConfig:
@dataclass(frozen=True)
class ExperimentResult:
    """Timing results for one benchmark configuration.

    Values are the rounded fwd+bwd microsecond timings recorded by
    run_experiment; frozen so results are immutable once recorded.
    """

    # Baseline bf16 grouped-mm fwd+bwd time in microseconds.
    bf16_us: float
    # Scaled (recipe-dependent, e.g. MXFP8 or FP8 rowwise) fwd+bwd time
    # in microseconds.
    scaled_us: float
    # bf16_us / scaled_us; values > 1 mean the scaled recipe is faster.
    scaled_speedup: float
4141
4242@dataclass (frozen = True )
@@ -48,8 +48,8 @@ class Experiment:
4848def get_configs () -> List [ExperimentConfig ]:
4949 # Llama4 shapes
5050 A_shapes = [(16640 , 5120 )]
51- B_shapes = [(1 , 8192 , 5120 ), ( 16 , 8192 , 5120 ), ( 128 , 8192 , 5120 )]
52- recipes = [MoEScalingType .FP8_ROWWISE ]
51+ B_shapes = [(16 , 8192 , 5120 )]
52+ recipes = [MoEScalingType .MXFP8 , MoEScalingType . FP8_ROWWISE ]
5353 high_precision_dtypes = [torch .bfloat16 ]
5454 configs = []
5555 for A_shape , B_shape , recipe , high_precision_dtype in itertools .product (
@@ -93,7 +93,8 @@ def run_experiment(
9393 # which represents the right operand.
9494 n_groups = config .B_shape [0 ]
9595 Mg = A .shape [0 ]
96- offs = generate_jagged_offs (n_groups , Mg , multiple_of = 16 )
96+ token_group_alignment_size = 32 if config .recipe == MoEScalingType .MXFP8 else 16
97+ offs = generate_jagged_offs (n_groups , Mg , multiple_of = token_group_alignment_size )
9798
9899 labels = torch .ones (
99100 (A .shape [0 ], B_t .shape [- 1 ]), device = device , dtype = torch .bfloat16
@@ -107,6 +108,7 @@ def run_experiment(
107108 offs ,
108109 labels = labels ,
109110 use_compile = args .compile ,
111+ fullgraph = False ,
110112 )
111113 if args .profile :
112114 profile_fwd_bwd (
@@ -116,18 +118,20 @@ def run_experiment(
116118 offs ,
117119 labels = labels ,
118120 use_compile = args .compile ,
121+ fullgraph = False ,
119122 profile_name = "bf16_profile" ,
120123 )
121124
122125 # benchmark scaled grouped mm with dynamic fp8 rowwise quant
123- fp8_us = bench_fwd_bwd_microseconds (
126+ scaled_us = bench_fwd_bwd_microseconds (
124127 _scaled_grouped_mm ,
125128 A ,
126129 B_t ,
127130 offs ,
128131 scaling_type = config .recipe ,
129132 labels = labels ,
130133 use_compile = args .compile ,
134+ fullgraph = False ,
131135 )
132136 if args .profile :
133137 profile_fwd_bwd (
@@ -139,22 +143,24 @@ def run_experiment(
139143 labels = labels ,
140144 use_compile = args .compile ,
141145 profile_name = "scaled_profile" ,
146+ fullgraph = False ,
142147 )
143148
144149 return ExperimentResult (
145150 bf16_us = round (bf16_us , 3 ),
146- fp8_us = round (fp8_us , 3 ),
147- fp8_speedup = round (bf16_us / fp8_us , 3 ),
151+ scaled_us = round (scaled_us , 3 ),
152+ scaled_speedup = round (bf16_us / scaled_us , 3 ),
148153 )
149154
150155
151156def print_results (experiments : List [Experiment ]):
152157 headers = [
153158 "A_shape" ,
154159 "B_shape" ,
160+ "recipe" ,
155161 "bf16_time_us" ,
156162 "scaled_time_us" ,
157- "fp8_speedup " ,
163+ "scaled_speedup " ,
158164 ]
159165 rows = []
160166 for experiment in experiments :
@@ -164,9 +170,10 @@ def print_results(experiments: List[Experiment]):
164170 [
165171 A_shape ,
166172 B_shape ,
173+ experiment .config .recipe ,
167174 experiment .result .bf16_us ,
168- experiment .result .fp8_us ,
169- f"{ experiment .result .fp8_speedup } x" ,
175+ experiment .result .scaled_us ,
176+ f"{ experiment .result .scaled_speedup } x" ,
170177 ]
171178 )
172179 print (tabulate (rows , headers = headers ))