Skip to content

Commit 568c193

Browse files
[moe training] update tests + benchmarks with conditional runs based on SM arch; make test cases more comprehensive and consistent (#2905)
1 parent 1bb1a40 commit 568c193

File tree

9 files changed

+435
-124
lines changed

9 files changed

+435
-124
lines changed

benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77
import argparse
88
import itertools
9+
import logging
910
from dataclasses import dataclass
1011
from typing import List
1112

@@ -105,10 +106,22 @@ def run_experiment(
105106
)
106107

107108
# bench fp8 rowwise grouped mm
108-
fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)
109+
if torch.cuda.get_device_capability() != (9, 0):
110+
logging.warning(
111+
f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
112+
)
113+
fp8_rowwise_us = float("inf")
114+
else:
115+
fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)
109116

110117
# benchmark mxfp8 grouped mm
111-
mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)
118+
if torch.cuda.get_device_capability() != (10, 0):
119+
logging.warning(
120+
f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
121+
)
122+
mxfp8_us = float("inf")
123+
else:
124+
mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)
112125

113126
return ExperimentResult(
114127
bf16_us=round(bf16_us, 3),
@@ -126,9 +139,25 @@ def print_results(experiments: List[Experiment]):
126139
"bf16_time_us",
127140
"fp8_rowwise_time_us",
128141
"mxfp8_time_us",
142+
"bf16_tflops",
143+
"fp8_rowwise_tflops",
144+
"mxfp8_tflops",
145+
"fp8_rowwise_speedup",
146+
"mxfp8_speedup",
129147
]
130148
rows = []
131149
for experiment in experiments:
150+
# calculate tflops
151+
e, m, n, k = (
152+
experiment.config.e,
153+
experiment.config.m,
154+
experiment.config.n,
155+
experiment.config.k,
156+
)
157+
flops = 2 * e * m * n * k
158+
bf16_tflops = (flops / 1e12) / (experiment.result.bf16_us / 1e6)
159+
fp8_rowwise_tflops = (flops / 1e12) / (experiment.result.fp8_rowwise_us / 1e6)
160+
mxfp8_tflops = (flops / 1e12) / (experiment.result.mxfp8_us / 1e6)
132161
rows.append(
133162
[
134163
experiment.config.e,
@@ -138,6 +167,11 @@ def print_results(experiments: List[Experiment]):
138167
experiment.result.bf16_us,
139168
experiment.result.fp8_rowwise_us,
140169
experiment.result.mxfp8_us,
170+
round(bf16_tflops, 3),
171+
round(fp8_rowwise_tflops, 3),
172+
round(mxfp8_tflops, 3),
173+
f"{experiment.result.bf16_us / experiment.result.fp8_rowwise_us:.2f}x",
174+
f"{experiment.result.bf16_us / experiment.result.mxfp8_us:.2f}x",
141175
]
142176
)
143177
print(tabulate(rows, headers=headers))

benchmarks/prototype/moe_training/benchmark_moe_fsdp.py renamed to benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
#
88
# To run these benchmarks, use the following command:
99
#
10-
# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py
10+
# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py
1111
#
1212
#######################################################################
1313

1414
import argparse
1515
import copy
16+
import logging
1617
import os
1718

1819
import pytest
@@ -23,13 +24,6 @@
2324
from torch.nn import functional as F
2425

2526
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
26-
27-
# this feature requires CUDA and SM89+
28-
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
29-
pytest.skip(
30-
"CUDA not available or compute capability < 8.9", allow_module_level=True
31-
)
32-
3327
from torchao.prototype.moe_training.conversion_utils import (
3428
MoEScalingType,
3529
MoETrainingConfig,
@@ -48,12 +42,27 @@
4842
)
4943

5044

51-
def bench_moe_float8_training_fsdp(
52-
recipe_name: str, enable_profile: bool, use_compile: bool
53-
):
45+
def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: bool):
5446
assert torch.cuda.is_available()
5547
assert recipe_name in ["fp8_rowwise", "mxfp8"]
5648
recipe = MoEScalingType[recipe_name.upper()]
49+
if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != (
50+
9,
51+
0,
52+
):
53+
logging.warning(
54+
f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
55+
)
56+
return
57+
58+
elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != (
59+
10,
60+
0,
61+
):
62+
logging.warning(
63+
f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
64+
)
65+
return
5766

5867
# setup distributed for fsdp
5968
setup_distributed()
@@ -157,14 +166,16 @@ def setup_distributed():
157166
action="store_true",
158167
help="Enable PyTorch profiling and save results to file",
159168
)
160-
parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
169+
parser.add_argument(
170+
"--recipe", type=str, help="[fp8_rowwise, mxfp8]", required=True
171+
)
161172
parser.add_argument(
162173
"--compile",
163174
action="store_true",
164175
help="use torch.compile",
165176
)
166177
args = parser.parse_args()
167-
bench_moe_float8_training_fsdp(
178+
bench_moe_training_fsdp(
168179
recipe_name=args.recipe,
169180
enable_profile=args.profile,
170181
use_compile=args.compile,

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77
import argparse
88
import itertools
9+
import logging
910
from dataclasses import dataclass
1011
from typing import List
1112

1213
import torch
1314
from tabulate import tabulate
1415
from tqdm import tqdm
1516

16-
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
17+
from benchmarks.utils import (
18+
bench_fwd_bwd_microseconds,
19+
bench_fwd_microseconds,
20+
profile_fwd_bwd,
21+
)
1722
from torchao.prototype.moe_training import _scaled_grouped_mm
1823
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
1924
from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -34,9 +39,12 @@ class ExperimentConfig:
3439

3540
@dataclass(frozen=True)
3641
class ExperimentResult:
37-
bf16_us: float
38-
scaled_us: float
39-
scaled_speedup: float
42+
bf16_e2e_us: float
43+
scaled_e2e_us: float
44+
scaled_e2e_speedup: float
45+
bf16_fwd_us: float
46+
scaled_fwd_us: float
47+
scaled_fwd_speedup: float
4048

4149

4250
@dataclass(frozen=True)
@@ -100,8 +108,8 @@ def run_experiment(
100108
(A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
101109
)
102110

103-
# benchmark bf16 grouped mm
104-
bf16_us = bench_fwd_bwd_microseconds(
111+
# E2E bf16 benchmark + profiling
112+
bf16_e2e_us = bench_fwd_bwd_microseconds(
105113
torch._grouped_mm,
106114
A,
107115
B_t,
@@ -122,8 +130,8 @@ def run_experiment(
122130
profile_name="bf16_profile",
123131
)
124132

125-
# benchmark scaled grouped mm with dynamic fp8 rowwise quant
126-
scaled_us = bench_fwd_bwd_microseconds(
133+
# E2E scaled benchmark + profiling
134+
scaled_e2e_us = bench_fwd_bwd_microseconds(
127135
_scaled_grouped_mm,
128136
A,
129137
B_t,
@@ -146,10 +154,32 @@ def run_experiment(
146154
fullgraph=False,
147155
)
148156

157+
# Forward pass benchmarks
158+
bf16_fwd_us = bench_fwd_microseconds(
159+
torch._grouped_mm,
160+
A,
161+
B_t,
162+
offs,
163+
use_compile=args.compile,
164+
fullgraph=True,
165+
)
166+
scaled_fwd_us = bench_fwd_microseconds(
167+
_scaled_grouped_mm,
168+
A,
169+
B_t,
170+
offs,
171+
scaling_type=config.recipe,
172+
use_compile=args.compile,
173+
fullgraph=True,
174+
)
175+
149176
return ExperimentResult(
150-
bf16_us=round(bf16_us, 3),
151-
scaled_us=round(scaled_us, 3),
152-
scaled_speedup=round(bf16_us / scaled_us, 3),
177+
bf16_e2e_us=round(bf16_e2e_us, 3),
178+
scaled_e2e_us=round(scaled_e2e_us, 3),
179+
scaled_e2e_speedup=round(bf16_e2e_us / scaled_e2e_us, 3),
180+
bf16_fwd_us=round(bf16_fwd_us, 3),
181+
scaled_fwd_us=round(scaled_fwd_us, 3),
182+
scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3),
153183
)
154184

155185

@@ -158,9 +188,12 @@ def print_results(experiments: List[Experiment]):
158188
"A_shape",
159189
"B_shape",
160190
"recipe",
161-
"bf16_time_us",
162-
"scaled_time_us",
163-
"scaled_speedup",
191+
"bf16_e2e_us",
192+
"scaled_e2e_us",
193+
"scaled_e2e_speedup",
194+
"bf16_fwd_us",
195+
"scaled_fwd_us",
196+
"scaled_fwd_speedup",
164197
]
165198
rows = []
166199
for experiment in experiments:
@@ -171,9 +204,12 @@ def print_results(experiments: List[Experiment]):
171204
A_shape,
172205
B_shape,
173206
experiment.config.recipe,
174-
experiment.result.bf16_us,
175-
experiment.result.scaled_us,
176-
f"{experiment.result.scaled_speedup}x",
207+
experiment.result.bf16_e2e_us,
208+
experiment.result.scaled_e2e_us,
209+
f"{experiment.result.scaled_e2e_speedup}x",
210+
experiment.result.bf16_fwd_us,
211+
experiment.result.scaled_fwd_us,
212+
f"{experiment.result.scaled_fwd_speedup}x",
177213
]
178214
)
179215
print(tabulate(rows, headers=headers))
@@ -184,6 +220,24 @@ def main(args: argparse.Namespace):
184220
configs = get_configs()
185221
results = []
186222
for config in tqdm(configs):
223+
if (
224+
config.recipe == MoEScalingType.FP8_ROWWISE
225+
and torch.cuda.get_device_capability() != (9, 0)
226+
):
227+
logging.warning(
228+
f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
229+
)
230+
continue
231+
232+
elif (
233+
config.recipe == MoEScalingType.MXFP8
234+
and torch.cuda.get_device_capability() != (10, 0)
235+
):
236+
logging.warning(
237+
f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
238+
)
239+
continue
240+
187241
result = run_experiment(config, args)
188242
results.append(Experiment(config=config, result=result))
189243

benchmarks/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@ def bench_fwd_bwd_microseconds(
88
):
99
assert labels is not None
1010

11-
def fwd_bwd():
11+
def fwd_bwd(*args, **kwargs):
1212
out = fn(*args, **kwargs)
1313
loss = F.mse_loss(out, labels)
1414
loss.backward()
1515

1616
fwd_bwd_compiled = (
1717
torch.compile(fwd_bwd, fullgraph=fullgraph) if use_compile else fwd_bwd
1818
)
19-
return benchmark_cuda_function_in_microseconds(fwd_bwd_compiled)
19+
return benchmark_cuda_function_in_microseconds(
20+
fwd_bwd_compiled,
21+
*args,
22+
**kwargs,
23+
)
24+
25+
26+
def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs):
27+
fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
28+
return benchmark_cuda_function_in_microseconds(
29+
fn_compiled,
30+
*args,
31+
**kwargs,
32+
)
2033

2134

2235
def profile_fwd_bwd(

test/prototype/moe_training/test_everything.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ IS_ROCM=$(rocm-smi --version || true)
1212
# These tests do not work on ROCm yet
1313
if [ -z "$IS_ROCM" ]
1414
then
15+
pytest test/prototype/moe_training/test_kernels.py -s
16+
pytest test/prototype/moe_training/test_training.py -s
1517
./test/prototype/moe_training/test_fsdp.sh
1618
./test/prototype/moe_training/test_tp.sh
1719
./test/prototype/moe_training/test_fsdp_tp.sh

0 commit comments

Comments
 (0)