 from tabulate import tabulate
 from tqdm import tqdm
 
-from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
+from benchmarks.utils import (
+    bench_fwd_bwd_microseconds,
+    bench_fwd_microseconds,
+    profile_fwd_bwd,
+)
 from torchao.prototype.moe_training import _scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -35,9 +39,12 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_us: float
-    scaled_us: float
-    scaled_speedup: float
+    bf16_e2e_us: float
+    scaled_e2e_us: float
+    scaled_e2e_speedup: float
+    bf16_fwd_us: float
+    scaled_fwd_us: float
+    scaled_fwd_speedup: float
 
 
 @dataclass(frozen=True)
@@ -101,8 +108,8 @@ def run_experiment(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
     )
 
-    # benchmark bf16 grouped mm
-    bf16_us = bench_fwd_bwd_microseconds(
+    # E2E bf16 benchmark + profiling
+    bf16_e2e_us = bench_fwd_bwd_microseconds(
         torch._grouped_mm,
         A,
         B_t,
@@ -123,8 +130,8 @@ def run_experiment(
         profile_name="bf16_profile",
     )
 
-    # benchmark scaled grouped mm with dynamic fp8 rowwise quant
-    scaled_us = bench_fwd_bwd_microseconds(
+    # E2E scaled benchmark + profiling
+    scaled_e2e_us = bench_fwd_bwd_microseconds(
         _scaled_grouped_mm,
         A,
         B_t,
@@ -147,10 +154,32 @@ def run_experiment(
         fullgraph=False,
     )
 
+    # Forward pass benchmarks
+    bf16_fwd_us = bench_fwd_microseconds(
+        torch._grouped_mm,
+        A,
+        B_t,
+        offs,
+        use_compile=args.compile,
+        fullgraph=True,
+    )
+    scaled_fwd_us = bench_fwd_microseconds(
+        _scaled_grouped_mm,
+        A,
+        B_t,
+        offs,
+        scaling_type=config.recipe,
+        use_compile=args.compile,
+        fullgraph=True,
+    )
+
     return ExperimentResult(
-        bf16_us=round(bf16_us, 3),
-        scaled_us=round(scaled_us, 3),
-        scaled_speedup=round(bf16_us / scaled_us, 3),
+        bf16_e2e_us=round(bf16_e2e_us, 3),
+        scaled_e2e_us=round(scaled_e2e_us, 3),
+        scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
+        bf16_fwd_us=round(bf16_fwd_us, 3),
+        scaled_fwd_us=round(scaled_fwd_us, 3),
+        scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
     )
 
 
@@ -159,9 +188,12 @@ def print_results(experiments: List[Experiment]):
         "A_shape",
         "B_shape",
         "recipe",
-        "bf16_time_us",
-        "scaled_time_us",
-        "scaled_speedup",
+        "bf16_e2e_us",
+        "scaled_e2e_us",
+        "scaled_e2e_speedup",
+        "bf16_fwd_us",
+        "scaled_fwd_us",
+        "scaled_fwd_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -172,9 +204,12 @@ def print_results(experiments: List[Experiment]):
                 A_shape,
                 B_shape,
                 experiment.config.recipe,
-                experiment.result.bf16_us,
-                experiment.result.scaled_us,
-                f"{experiment.result.scaled_speedup}x",
+                experiment.result.bf16_e2e_us,
+                experiment.result.scaled_e2e_us,
+                f"{experiment.result.scaled_e2e_speedup}x",
+                experiment.result.bf16_fwd_us,
+                experiment.result.scaled_fwd_us,
+                f"{experiment.result.scaled_fwd_speedup}x",
             ]
         )
     print(tabulate(rows, headers=headers))
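
The diff starts calling `bench_fwd_microseconds` from `benchmarks.utils`, whose body is not shown here. As a rough reference, a forward-only timing helper with this calling convention could look like the sketch below; this is an illustrative assumption (hence the hypothetical name `bench_fwd_microseconds_sketch`), not the actual `benchmarks.utils` implementation.

```python
# Illustrative sketch only -- not the benchmarks.utils implementation.
# Assumes a CUDA device and that fn(*args, **kwargs) is the forward op to time.
import torch
from torch.utils.benchmark import Timer


def bench_fwd_microseconds_sketch(fn, *args, use_compile=False, fullgraph=True, **kwargs):
    """Time only the forward call of `fn`, returning median latency in microseconds."""
    if use_compile:
        fn = torch.compile(fn, fullgraph=fullgraph)
    # Warm up so compilation/autotuning is excluded from the measurement.
    for _ in range(3):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    timer = Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
    )
    return timer.blocked_autorange().median * 1e6  # seconds -> microseconds
```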