
Commit cbd3adb

[mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts (#3088)
[mxfp8 moe training] fix CUDA IMA and improve bench + test scripts
1 parent 5cbbd73 commit cbd3adb

4 files changed: +237 -123 lines


benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 162 additions & 80 deletions
@@ -10,6 +10,7 @@
 # torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
+import argparse
 import os
 import time
 from dataclasses import dataclass
@@ -18,11 +19,14 @@
 import torch
 from tabulate import tabulate
 from torch import distributed as dist
+from torch.distributed import DeviceMesh, init_device_mesh
 from torch.distributed._functional_collectives import (
+    all_to_all_single,
     all_to_all_single_autograd,
 )
 from tqdm import tqdm
 
+from benchmarks.utils import profile_fn
 from torchao.prototype.moe_training.kernels.mxfp8.comms import (
     mxfp8_on_device_all_to_all_v,
 )
@@ -37,8 +41,8 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_us: float
-    mxfp8_us: float
+    bf16_ms: float
+    mxfp8_ms: float
 
 
 @dataclass(frozen=True)
@@ -50,7 +54,7 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # (batch_size, seq_len, dim)
     input_shapes = [
-        (8, 8192, 5120),
+        (16, 8192, 5120),
     ]
     configs = []
     for shape in input_shapes:
@@ -62,7 +66,111 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
+def default_a2a_dispatch(
+    routed_input: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor,
+    device_mesh: DeviceMesh,
+):
+    """
+    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
+
+    Returns:
+        routed_input: the local tokens after all-to-all dispatch
+        input_splits: the input splits for all-to-all dispatch
+        output_splits: the output splits for all-to-all dispatch
+        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
+    """
+    ep_degree = device_mesh.size(0)
+    # generate the input splits and output splits for all-to-all
+    with torch.no_grad():
+        num_tokens_per_expert_group = all_to_all_single(
+            num_tokens_per_expert,
+            None,
+            None,
+            group=device_mesh.get_group(),
+        )
+        # Need to wait explicitly because it is used by a triton kernel later
+        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+            num_tokens_per_expert_group
+        )
+        input_splits = (
+            num_tokens_per_expert.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=True)
+        )
+        # NOTE: this would incur a device-to-host sync
+        output_splits = (
+            num_tokens_per_expert_group.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=False)
+        )
+    input_splits_list = input_splits.tolist()
+    output_splits_list = output_splits.tolist()
+
+    # perform all-to-all
+    routed_input = all_to_all_single_autograd(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+    routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+    return (
+        routed_input,
+        input_splits_list,
+        output_splits_list,
+        num_tokens_per_expert_group,
+    )
+
+
+def mxfp8_a2a_dispatch(
+    routed_input: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor,
+    device_mesh: DeviceMesh,
+    max_tokens_per_ep_rank: int,
+):
+    """
+    Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
+    and avoid device-to-host sync.
+
+    Returns:
+        routed_input: the local tokens after all-to-all dispatch
+        input_splits: the input splits for all-to-all dispatch
+        output_splits: the output splits for all-to-all dispatch
+    """
+    ep_degree = device_mesh.size(0)
+    num_tokens_per_expert_group = all_to_all_single(
+        num_tokens_per_expert,
+        None,
+        None,
+        group=device_mesh.get_group(),
+    )
+    input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
+    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+        num_tokens_per_expert_group
+    )
+    routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
+        routed_input,
+        input_splits_per_ep_rank,
+        max_tokens_per_ep_rank,
+        device_mesh.get_group().group_name,
+    )
+    tokens_on_rank_after_a2a = output_splits_per_ep_rank.sum()
+    routed_input_no_padding = routed_input[:tokens_on_rank_after_a2a]
+    return (
+        routed_input_no_padding,
+        input_splits_per_ep_rank,
+        output_splits_per_ep_rank,
+        num_tokens_per_expert_group,
+    )
+
+
+def run_experiment(
+    config: ExperimentConfig, args: argparse.Namespace
+) -> ExperimentResult:
     batch_size, seq_len, dim = config.input_shape
     x = torch.randn(
         (batch_size * seq_len, dim),
@@ -71,99 +179,70 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     )
     ref_x = x.detach().clone()
 
+    # Set up device mesh
+    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
+
     # Max output tokens per rank is worst case where one rank receives all tokens
     input_tokens_per_rank = batch_size * seq_len
    max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()
 
-    def using_bf16(
-        input_tensor: torch.Tensor, input_splits: torch.Tensor
-    ) -> torch.Tensor:
-        # Calculate output splits from input splits
-        output_splits = torch.empty_like(input_splits)
-        dist.all_to_all_single(output_splits, input_splits)
-
-        # Perform all-to-all
-        out = all_to_all_single_autograd(
-            input_tensor,
-            output_splits.tolist(),
-            input_splits.tolist(),
-            dist.group.WORLD,
-        )
-        out = torch.ops._c10d_functional.wait_tensor(out)
-        return out
-
-    def using_mxfp8(
-        input_tensor: torch.Tensor, input_splits: torch.Tensor
-    ) -> torch.Tensor:
-        output, output_splits = mxfp8_on_device_all_to_all_v(
-            input_tensor,
-            input_splits,
-            max_output_tokens_per_rank,
-            dist.group.WORLD.group_name,
-        )
-        output = torch.ops._c10d_functional.wait_tensor(output)
-        output_splits = torch.ops._c10d_functional.wait_tensor(output_splits)
-        return output
-
     def warmup(func_no_args):
         for _ in range(2):
             func_no_args()
 
-    num_splits = dist.get_world_size()
+    num_experts_per_rank = 2
+    num_splits = dist.get_world_size() * num_experts_per_rank
     input_splits = generate_split_sizes(
         num_splits, input_tokens_per_rank, device=device
     )
 
-    print(
-        "Benchmarking using bf16",
-        "batch_size",
-        batch_size,
-        "seq_len",
-        seq_len,
-        "dim",
-        dim,
-        "input_tokens_per_rank",
-        input_tokens_per_rank,
-        "max_output_tokens_per_rank",
-        max_output_tokens_per_rank,
-    )
-    warmup(lambda: using_bf16(ref_x, input_splits))
-    start_ns = time.perf_counter()
-    using_bf16(ref_x, input_splits)
-    end_ns = time.perf_counter()
-    bf16_us = (end_ns - start_ns) * 1e6
-
-    print(
-        "Benchmarking using_mxfp8",
-        "batch_size",
-        batch_size,
-        "seq_len",
-        seq_len,
-        "dim",
-        dim,
-        "input_tokens_per_rank",
-        input_tokens_per_rank,
-        "max_output_tokens_per_rank",
-        max_output_tokens_per_rank,
+    # Bench default a2a
+    warmup(lambda: default_a2a_dispatch(ref_x, input_splits, mesh))
+    start_sec = time.perf_counter()
+    default_a2a_dispatch(ref_x, input_splits, mesh)
+    end_sec = time.perf_counter()
+    bf16_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            default_a2a_dispatch,
+            ref_x,
+            input_splits,
+            mesh,
+            distributed=True,
+            profile_name="all_to_all_single_autograd",
+        )
+
+    # Bench mxfp8 a2a
+    warmup(
+        lambda: mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
     )
-    warmup(lambda: using_mxfp8(x, input_splits))
-    start_ns = time.perf_counter()
-    using_mxfp8(x, input_splits)
-    end_ns = time.perf_counter()
-    mxfp8_us = (end_ns - start_ns) * 1e6
+    start_sec = time.perf_counter()
+    mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
+    end_sec = time.perf_counter()
+    mxfp8_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mxfp8_a2a_dispatch,
+            x,
+            input_splits,
+            mesh,
+            max_output_tokens_per_rank,
+            distributed=True,
+            profile_name="mxfp8_all_to_all_v",
+        )
 
     return ExperimentResult(
-        bf16_us=bf16_us,
-        mxfp8_us=mxfp8_us,
+        bf16_ms=bf16_ms,
+        mxfp8_ms=mxfp8_ms,
    )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_us",
-        "mxfp8_us",
+        "bf16_ms",
+        "mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -172,8 +251,8 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_us,
-                experiment.result.mxfp8_us,
+                experiment.result.bf16_ms,
+                experiment.result.mxfp8_ms,
             ]
         )
    print(tabulate(rows, headers=headers))
@@ -209,7 +288,7 @@ def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
     return result.to(dtype=torch.int64)
 
 
-def main():
+def main(args: argparse.Namespace):
     torch.random.manual_seed(123)
 
     # Set up process group
@@ -219,7 +298,7 @@ def main():
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, args)
         results.append(Experiment(config=config, result=result))
 
     # Use Tabulate to print results
@@ -237,4 +316,7 @@ def setup_distributed():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--profile", action="store_true")
+    args = parser.parse_args()
+    main(args)
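
For context on the bandwidth claim in the mxfp8_a2a_dispatch docstring above, here is a rough back-of-envelope sketch (not part of the commit) of the per-rank dispatch payload for the benchmarked shape. It assumes mxfp8 traffic is one FP8 byte per element plus one E8M0 scale byte per 32-element block, versus two bytes per element for bf16; the actual on-wire layout of mxfp8_on_device_all_to_all_v may differ.

# Back-of-envelope payload estimate for one rank's dispatched tokens.
# Assumptions (not taken from the kernel): mxfp8 = 1 byte/element plus
# 1 E8M0 scale byte per 32-element block; bf16 = 2 bytes/element.
batch_size, seq_len, dim = 16, 8192, 5120  # shape used by get_configs()
tokens = batch_size * seq_len

bf16_bytes = tokens * dim * 2
mxfp8_bytes = tokens * dim * 1 + tokens * (dim // 32)  # data + blockwise scales

print(f"bf16 payload:  {bf16_bytes / 1e9:.2f} GB")
print(f"mxfp8 payload: {mxfp8_bytes / 1e9:.2f} GB")
print(f"ratio: {mxfp8_bytes / bf16_bytes:.3f}")  # ~0.52x of the bf16 bytes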

benchmarks/utils.py

Lines changed: 6 additions & 4 deletions
@@ -72,7 +72,7 @@ def profile_fwd_bwd(
     print(f"Saved: {profile_name}.json")
 
 
-def profile_fn(fn, *args, profile_name="profile", **kwargs):
+def profile_fn(fn, *args, profile_name="profile", distributed=False, **kwargs):
     wait, warmup, active = 1, 1, 1
     total_steps = wait + warmup + active
     with torch.profiler.profile(
@@ -89,9 +89,11 @@ def profile_fn(fn, *args, profile_name="profile", **kwargs):
             _ = fn(*args, **kwargs)
             prof.step()
 
-    # Save profiler results
-    prof.export_chrome_trace(f"{profile_name}.json")
-    print(f"Saved: {profile_name}.json")
+    if distributed:
+        if torch.distributed.get_rank() == 0:
+            # Save profiler results
+            prof.export_chrome_trace(f"{profile_name}.json")
+            print(f"Saved: {profile_name}.json")
 
 
 def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
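
A minimal usage sketch of the updated profile_fn under the new distributed flag; the profiled function and tensor shapes below are illustrative, not taken from the repo. With distributed=True, only rank 0 exports the chrome trace, so a multi-rank torchrun job writes a single {profile_name}.json instead of one file per rank (and, as written, the trace is only exported on this distributed path).

import torch
import torch.distributed as dist

from benchmarks.utils import profile_fn


def toy_step(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for whatever collective or kernel is being profiled.
    return a @ b


# Assumes this runs under torchrun, which provides the rendezvous env vars.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

# Only rank 0 writes "toy_step.json"; all ranks still run the profiled steps.
profile_fn(toy_step, a, b, distributed=True, profile_name="toy_step")

dist.destroy_process_group()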

test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 
 
 @instantiate_parametrized_tests
-class TritonAllReduceTest(MultiProcessTestCase):
+class MXFP8AllToAllVTest(MultiProcessTestCase):
     def setUp(self) -> None:
         super().setUp()
         self._spawn_processes()
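
For readers unfamiliar with the harness behind the renamed test class, here is a minimal sketch of the MultiProcessTestCase pattern it follows; the class name, collective, backend, and world size are illustrative and not taken from test_mxfp8_a2a.py.

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests


class ExampleDistributedTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()  # each test method runs once per spawned rank

    @property
    def world_size(self) -> int:
        return 2

    def test_all_reduce(self):
        # Each spawned process joins the same group via a shared file store.
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo", rank=self.rank, world_size=self.world_size, store=store
        )
        x = torch.ones(2) * (self.rank + 1)
        dist.all_reduce(x)  # sums across ranks: 1 + 2 with world_size=2
        self.assertTrue(torch.equal(x, torch.full((2,), 3.0)))
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()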
