Commit 48fb79d

[moe training] update tests + benchmarks with conditional runs based on SM arch; make test cases more comprehensive and consistent
Parent: 843448d

8 files changed: +261 -57 lines changed

benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py

Lines changed: 36 additions & 2 deletions
@@ -6,6 +6,7 @@
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
 import argparse
 import itertools
+import logging
 from dataclasses import dataclass
 from typing import List
 
@@ -105,10 +106,22 @@ def run_experiment(
     )
 
     # bench fp8 rowwise grouped mm
-    fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)
+    if torch.cuda.get_device_capability() != (9, 0):
+        logging.warning(
+            f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+        )
+        fp8_rowwise_us = float("inf")
+    else:
+        fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)
 
     # benchmark mxfp8 grouped mm
-    mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)
+    if torch.cuda.get_device_capability() != (10, 0):
+        logging.warning(
+            f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+        )
+        mxfp8_us = float("inf")
+    else:
+        mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)
 
     return ExperimentResult(
         bf16_us=round(bf16_us, 3),
@@ -126,9 +139,25 @@ def print_results(experiments: List[Experiment]):
         "bf16_time_us",
         "fp8_rowwise_time_us",
         "mxfp8_time_us",
+        "bf16_tflops",
+        "fp8_rowwise_tflops",
+        "mxfp8_tflops",
+        "fp8_rowwise_speedup",
+        "mxfp8_speedup",
     ]
     rows = []
     for experiment in experiments:
+        # calculate tflops
+        e, m, n, k = (
+            experiment.config.e,
+            experiment.config.m,
+            experiment.config.n,
+            experiment.config.k,
+        )
+        flops = 2 * e * m * n * k
+        bf16_tflops = (flops / 1e12) / (experiment.result.bf16_us / 1e6)
+        fp8_rowwise_tflops = (flops / 1e12) / (experiment.result.fp8_rowwise_us / 1e6)
+        mxfp8_tflops = (flops / 1e12) / (experiment.result.mxfp8_us / 1e6)
         rows.append(
             [
                 experiment.config.e,
@@ -138,6 +167,11 @@ def print_results(experiments: List[Experiment]):
                 experiment.result.bf16_us,
                 experiment.result.fp8_rowwise_us,
                 experiment.result.mxfp8_us,
+                round(bf16_tflops, 3),
+                round(fp8_rowwise_tflops, 3),
+                round(mxfp8_tflops, 3),
+                f"{experiment.result.bf16_us / experiment.result.fp8_rowwise_us:.2f}x",
+                f"{experiment.result.bf16_us / experiment.result.mxfp8_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
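
Note on the new TFLOPS columns: a grouped GEMM over e experts with per-expert problem size m x n x k performs 2 * e * m * n * k floating-point ops, and the measured time is converted from microseconds to seconds before dividing. A minimal standalone sketch of the same arithmetic, with made-up shapes and timing for illustration:

    # Illustrative only: mirrors the TFLOPS math added to print_results above.
    e, m, n, k = 8, 1024, 8192, 5120  # hypothetical grouped GEMM problem size
    bench_us = 1500.0                 # hypothetical measured latency in microseconds
    flops = 2 * e * m * n * k         # 2 ops (multiply + add) per MAC, e independent GEMMs
    tflops = (flops / 1e12) / (bench_us / 1e6)  # TFLOP divided by seconds
    print(f"{tflops:.3f} TFLOPS")     # ~458 TFLOPS for these illustrative values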

benchmarks/prototype/moe_training/benchmark_moe_fsdp.py renamed to benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py

Lines changed: 24 additions & 13 deletions
@@ -7,12 +7,13 @@
 #
 # To run these benchmarks, use the following command:
 #
-# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py
+# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py
 #
 #######################################################################
 
 import argparse
 import copy
+import logging
 import os
 
 import pytest
@@ -23,13 +24,6 @@
 from torch.nn import functional as F
 
 from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
-
-# this feature requires CUDA and SM89+
-if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
-    pytest.skip(
-        "CUDA not available or compute capability < 8.9", allow_module_level=True
-    )
-
 from torchao.prototype.moe_training.conversion_utils import (
     MoEScalingType,
     MoETrainingConfig,
@@ -48,12 +42,27 @@
 )
 
 
-def bench_moe_float8_training_fsdp(
-    recipe_name: str, enable_profile: bool, use_compile: bool
-):
+def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: bool):
     assert torch.cuda.is_available()
     assert recipe_name in ["fp8_rowwise", "mxfp8"]
     recipe = MoEScalingType[recipe_name.upper()]
+    if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != (
+        9,
+        0,
+    ):
+        logging.warning(
+            f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+        )
+        return
+
+    elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != (
+        10,
+        0,
+    ):
+        logging.warning(
+            f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+        )
+        return
 
     # setup distributed for fsdp
     setup_distributed()
@@ -157,14 +166,16 @@ def setup_distributed():
         action="store_true",
         help="Enable PyTorch profiling and save results to file",
     )
-    parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
+    parser.add_argument(
+        "--recipe", type=str, help="[fp8_rowwise, mxfp8]", required=True
+    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="use torch.compile",
    )
    args = parser.parse_args()
-    bench_moe_float8_training_fsdp(
+    bench_moe_training_fsdp(
        recipe_name=args.recipe,
        enable_profile=args.profile,
        use_compile=args.compile,
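
Usage note: with --recipe now required, a typical run of the renamed script is torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py --recipe mxfp8 (or --recipe fp8_rowwise), with --compile and --profile as optional flags.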

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 19 additions & 0 deletions
@@ -6,6 +6,7 @@
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
 import argparse
 import itertools
+import logging
 from dataclasses import dataclass
 from typing import List
 
@@ -184,6 +185,24 @@ def main(args: argparse.Namespace):
     configs = get_configs()
     results = []
     for config in tqdm(configs):
+        if (
+            config.recipe == MoEScalingType.FP8_ROWWISE
+            and torch.cuda.get_device_capability() != (9, 0)
+        ):
+            logging.warning(
+                f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+            )
+            continue
+
+        elif (
+            config.recipe == MoEScalingType.MXFP8
+            and torch.cuda.get_device_capability() != (10, 0)
+        ):
+            logging.warning(
+                f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+            )
+            continue
+
         result = run_experiment(config, args)
         results.append(Experiment(config=config, result=result))

test/prototype/moe_training/test_everything.sh

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ IS_ROCM=$(rocm-smi --version || true)
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
 then
+pytest test/prototype/moe_training/test_kernels.py -s
+pytest test/prototype/moe_training/test_training.py -s
 ./test/prototype/moe_training/test_fsdp.sh
 ./test/prototype/moe_training/test_tp.sh
 ./test/prototype/moe_training/test_fsdp_tp.sh

test/prototype/moe_training/test_fsdp.py

Lines changed: 57 additions & 14 deletions
@@ -34,15 +34,15 @@
     "CUDA not available or compute capability < 8.9", allow_module_level=True
 )
 
-from testing_utils import _validate_model_conversion
-
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import (
     MoEScalingType,
     MoETrainingConfig,
 )
 from torchao.quantization.quant_api import quantize_
 
+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
@@ -54,27 +54,71 @@
 
 
 @pytest.mark.parametrize(
-    "recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr",
+    "target_fqns",
+    [
+        ["experts"],
+        ["does.not.exist"],
+    ],
+)
+@pytest.mark.parametrize("compile", [False, True])
+@pytest.mark.parametrize(
+    "recipe_config",
     [
-        (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
-        (MoEScalingType.MXFP8, 28.0, 32, 21.0),
+        {
+            "recipe": MoEScalingType.FP8_ROWWISE,
+            "group_alignment_size": 16,
+            "min_out_sqnr": 29.0,
+            "min_input_grad_sqnr": 29.0,
+            "min_param_grad_sqnr": 23.0,
+        },
+        {
+            "recipe": MoEScalingType.MXFP8,
+            "group_alignment_size": 32,
+            "min_out_sqnr": 28.0,
+            "min_input_grad_sqnr": 29.0,
+            "min_param_grad_sqnr": 21.0,
+        },
     ],
 )
-def test_moe_float8_training_fsdp(
-    recipe: MoEScalingType,
-    min_out_sqnr: float,
-    alignment_size: int,
-    min_param_grad_sqnr: float,
-):
+def test_moe_training_fsdp(target_fqns: list[str], compile: bool, recipe_config: dict):
+    (
+        recipe,
+        group_alignment_size,
+        min_out_sqnr,
+        min_input_grad_sqnr,
+        min_param_grad_sqnr,
+    ) = (
+        recipe_config["recipe"],
+        recipe_config["group_alignment_size"],
+        recipe_config["min_out_sqnr"],
+        recipe_config["min_input_grad_sqnr"],
+        recipe_config["min_param_grad_sqnr"],
+    )
     assert torch.cuda.is_available()
+    if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != (
+        9,
+        0,
+    ):
+        pytest.skip(
+            f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+        )
+
+    elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != (
+        10,
+        0,
+    ):
+        pytest.skip(
+            f"Skipping MXFP8 tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+        )
 
     # setup distributed for fsdp
     setup_distributed()
 
-    set_token_group_alignment_size_m(alignment_size)
+    # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned)
+    # or quantization ops (mxfp8 scaling groups are size 1x32)
+    set_token_group_alignment_size_m(group_alignment_size)
 
     # define model args
-    target_fqns = ["experts"]
     model_args = MoEArgs(
         num_experts=8,
     )
@@ -143,7 +187,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    min_input_grad_sqnr = 29.0
     assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
         f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
    )
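
The SQNR thresholds above are checked against compute_error from torchao.float8.float8_utils, which reports signal-to-quantization-noise ratio in decibels. A minimal sketch of the metric, assuming the standard 20 * log10 norm-ratio definition:

    import torch

    # Sketch of the SQNR metric the thresholds above refer to, assuming the
    # standard 20*log10(||ref|| / ||ref - test||) definition in dB.
    def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
        signal = torch.norm(ref)        # magnitude of the reference signal
        noise = torch.norm(ref - test)  # magnitude of the quantization error
        return 20 * torch.log10(signal / noise)

    # e.g. the MXFP8 case above requires sqnr_db(ref_out, quantized_out) >= 28.0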
