Skip to content

Commit e47baf7

Browse files
split fwd and e2e in autograd func bench script
1 parent d401e70 commit e47baf7

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from tabulate import tabulate
1515
from tqdm import tqdm
1616

17-
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
17+
from benchmarks.utils import (
18+
bench_fwd_bwd_microseconds,
19+
bench_fwd_microseconds,
20+
profile_fwd_bwd,
21+
)
1822
from torchao.prototype.moe_training import _scaled_grouped_mm
1923
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
2024
from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -35,9 +39,12 @@ class ExperimentConfig:
3539

3640
@dataclass(frozen=True)
3741
class ExperimentResult:
38-
bf16_us: float
39-
scaled_us: float
40-
scaled_speedup: float
42+
bf16_e2e_us: float
43+
scaled_e2e_us: float
44+
scaled_e2e_speedup: float
45+
bf16_fwd_us: float
46+
scaled_fwd_us: float
47+
scaled_fwd_speedup: float
4148

4249

4350
@dataclass(frozen=True)
@@ -101,8 +108,8 @@ def run_experiment(
101108
(A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
102109
)
103110

104-
# benchmark bf16 grouped mm
105-
bf16_us = bench_fwd_bwd_microseconds(
111+
# E2E bf16 benchmark + profiling
112+
bf16_e2e_us = bench_fwd_bwd_microseconds(
106113
torch._grouped_mm,
107114
A,
108115
B_t,
@@ -123,8 +130,8 @@ def run_experiment(
123130
profile_name="bf16_profile",
124131
)
125132

126-
# benchmark scaled grouped mm with dynamic fp8 rowwise quant
127-
scaled_us = bench_fwd_bwd_microseconds(
133+
# E2E scaled benchmark + profiling
134+
scaled_e2e_us = bench_fwd_bwd_microseconds(
128135
_scaled_grouped_mm,
129136
A,
130137
B_t,
@@ -147,10 +154,32 @@ def run_experiment(
147154
fullgraph=False,
148155
)
149156

157+
# Forward pass benchmarks
158+
bf16_fwd_us = bench_fwd_microseconds(
159+
torch._grouped_mm,
160+
A,
161+
B_t,
162+
offs,
163+
use_compile=args.compile,
164+
fullgraph=True,
165+
)
166+
scaled_fwd_us = bench_fwd_microseconds(
167+
_scaled_grouped_mm,
168+
A,
169+
B_t,
170+
offs,
171+
scaling_type=config.recipe,
172+
use_compile=args.compile,
173+
fullgraph=True,
174+
)
175+
150176
return ExperimentResult(
151-
bf16_us=round(bf16_us, 3),
152-
scaled_us=round(scaled_us, 3),
153-
scaled_speedup=round(bf16_us / scaled_us, 3),
177+
bf16_e2e_us=round(bf16_e2e_us, 3),
178+
scaled_e2e_us=round(scaled_e2e_us, 3),
179+
scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
180+
bf16_fwd_us=round(bf16_fwd_us, 3),
181+
scaled_fwd_us=round(scaled_fwd_us, 3),
182+
scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
154183
)
155184

156185

@@ -159,9 +188,12 @@ def print_results(experiments: List[Experiment]):
159188
"A_shape",
160189
"B_shape",
161190
"recipe",
162-
"bf16_time_us",
163-
"scaled_time_us",
164-
"scaled_speedup",
191+
"bf16_e2e_us",
192+
"scaled_e2e_us",
193+
"scaled_e2e_speedup",
194+
"bf16_fwd_us",
195+
"scaled_fwd_us",
196+
"scaled_fwd_speedup",
165197
]
166198
rows = []
167199
for experiment in experiments:
@@ -172,9 +204,12 @@ def print_results(experiments: List[Experiment]):
172204
A_shape,
173205
B_shape,
174206
experiment.config.recipe,
175-
experiment.result.bf16_us,
176-
experiment.result.scaled_us,
177-
f"{experiment.result.scaled_speedup}x",
207+
experiment.result.bf16_e2e_us,
208+
experiment.result.scaled_e2e_us,
209+
f"{experiment.result.scaled_e2e_speedup}x",
210+
experiment.result.bf16_fwd_us,
211+
experiment.result.scaled_fwd_us,
212+
f"{experiment.result.scaled_fwd_speedup}x",
178213
]
179214
)
180215
print(tabulate(rows, headers=headers))

benchmarks/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@ def bench_fwd_bwd_microseconds(
88
):
99
assert labels is not None
1010

11-
def fwd_bwd():
11+
def fwd_bwd(*args, **kwargs):
1212
out = fn(*args, **kwargs)
1313
loss = F.mse_loss(out, labels)
1414
loss.backward()
1515

1616
fwd_bwd_compiled = (
1717
torch.compile(fwd_bwd, fullgraph=fullgraph) if use_compile else fwd_bwd
1818
)
19-
return benchmark_cuda_function_in_microseconds(fwd_bwd_compiled)
19+
return benchmark_cuda_function_in_microseconds(
20+
fwd_bwd_compiled,
21+
*args,
22+
**kwargs,
23+
)
24+
25+
26+
def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs):
27+
fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
28+
return benchmark_cuda_function_in_microseconds(
29+
fn_compiled,
30+
*args,
31+
**kwargs,
32+
)
2033

2134

2235
def profile_fwd_bwd(

0 commit comments

Comments (0)