
Commit 13768b1

Update on "add IntxUnpackedToInt8Tensor to safetensors"

Adds `IntxUnpackedToInt8Tensor` to safetensors support (`IntxWeightOnlyConfig`) and updates the corresponding unit test: `python test/prototype/safetensors/test_safetensors_support.py`

[ghstack-poisoned]
2 parents 9c0b4de + 5884d16 commit 13768b1
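
For context, a minimal sketch of the workflow the updated test exercises. This assumes torchao's public `quantize_` API and an `IntxWeightOnlyConfig(weight_dtype=..., granularity=...)` signature; the prototype safetensors save/load helpers themselves are not shown here, so the round-trip step is only referenced (see the test file for the actual API):

import torch

from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# Toy model; quantize_ replaces the linear weight with a quantized tensor
# subclass (with the appropriate config version, the IntxUnpackedToInt8Tensor
# representation this stack serializes)
model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)

# The state dict now holds the quantized weight; the prototype safetensors
# support is what lets it round-trip through safetensors
# (see test/prototype/safetensors/test_safetensors_support.py)
state_dict = model.state_dict()
print(type(model[0].weight))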

File tree

11 files changed: +1125 −25 lines
benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these benchmarks, use the following command:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
#
######################################################################
import os
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch import distributed as dist
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
)
from tqdm import tqdm

from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)

device = torch.device("cuda")


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int, int]


@dataclass(frozen=True)
class ExperimentResult:
    bf16_us: float
    mxfp8_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # (batch_size, seq_len, dim)
    input_shapes = [
        (8, 8192, 5120),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    batch_size, seq_len, dim = config.input_shape
    x = torch.randn(
        (batch_size * seq_len, dim),
        dtype=torch.bfloat16,
        device=device,
    )
    ref_x = x.detach().clone()

    # Max output tokens per rank is the worst case, where one rank receives all tokens
    input_tokens_per_rank = batch_size * seq_len
    max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()

    def using_bf16(
        input_tensor: torch.Tensor, input_splits: torch.Tensor
    ) -> torch.Tensor:
        # Calculate output splits from input splits
        output_splits = torch.empty_like(input_splits)
        dist.all_to_all_single(output_splits, input_splits)

        # Perform all-to-all
        out = all_to_all_single_autograd(
            input_tensor,
            output_splits.tolist(),
            input_splits.tolist(),
            dist.group.WORLD,
        )
        out = torch.ops._c10d_functional.wait_tensor(out)
        return out

    def using_mxfp8(
        input_tensor: torch.Tensor, input_splits: torch.Tensor
    ) -> torch.Tensor:
        output, output_splits = mxfp8_on_device_all_to_all_v(
            input_tensor,
            input_splits,
            max_output_tokens_per_rank,
            dist.group.WORLD.group_name,
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        output_splits = torch.ops._c10d_functional.wait_tensor(output_splits)
        return output

    def warmup(func_no_args):
        for _ in range(2):
            func_no_args()

    num_splits = dist.get_world_size()
    input_splits = generate_split_sizes(
        num_splits, input_tokens_per_rank, device=device
    )

    print(
        "Benchmarking using_bf16",
        "batch_size",
        batch_size,
        "seq_len",
        seq_len,
        "dim",
        dim,
        "input_tokens_per_rank",
        input_tokens_per_rank,
        "max_output_tokens_per_rank",
        max_output_tokens_per_rank,
    )
    warmup(lambda: using_bf16(ref_x, input_splits))
    # Synchronize so the host-side timer captures the full device execution
    torch.cuda.synchronize()
    start_s = time.perf_counter()
    using_bf16(ref_x, input_splits)
    torch.cuda.synchronize()
    end_s = time.perf_counter()
    bf16_us = (end_s - start_s) * 1e6

    print(
        "Benchmarking using_mxfp8",
        "batch_size",
        batch_size,
        "seq_len",
        seq_len,
        "dim",
        dim,
        "input_tokens_per_rank",
        input_tokens_per_rank,
        "max_output_tokens_per_rank",
        max_output_tokens_per_rank,
    )
    warmup(lambda: using_mxfp8(x, input_splits))
    torch.cuda.synchronize()
    start_s = time.perf_counter()
    using_mxfp8(x, input_splits)
    torch.cuda.synchronize()
    end_s = time.perf_counter()
    mxfp8_us = (end_s - start_s) * 1e6

    return ExperimentResult(
        bf16_us=bf16_us,
        mxfp8_us=mxfp8_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "num_splits",
        "bf16_us",
        "mxfp8_us",
    ]
    rows = []
    num_splits = dist.get_world_size()
    for experiment in experiments:
        rows.append(
            [
                str(experiment.config.input_shape),
                num_splits,
                experiment.result.bf16_us,
                experiment.result.mxfp8_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
    """
    Generates a tensor of K random non-negative integers that sum to N.
    Used for testing the mxfp8_on_device_all_to_all_v implementation.
    """
    if K <= 0:
        raise ValueError("K must be a positive integer.")
    if N < 0:
        raise ValueError("N must be a non-negative integer.")

    if K == 1:
        return torch.tensor([N], dtype=torch.long, device=device)

    # Generate K-1 random "dividers" in the range [0, N].
    dividers = torch.randint(0, N + 1, (K - 1,), device=device)

    # Add 0 and N to the set of dividers to form the boundaries.
    boundaries = torch.cat(
        [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)]
    )

    # Sort the boundaries to ensure they are in order.
    sorted_boundaries = torch.sort(boundaries).values

    # The K integers are the differences between consecutive boundaries (they sum to N).
    result = sorted_boundaries[1:] - sorted_boundaries[:-1]

    return result.to(dtype=torch.int64)


def main():
    torch.random.manual_seed(123)

    # Set up process group
    setup_distributed()

    # Generate experiment configs
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)

    # Clean up process group
    dist.destroy_process_group()


def setup_distributed():
    # torchrun sets RANK and WORLD_SIZE; pin each rank to its GPU
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


if __name__ == "__main__":
    main()
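
For reference, a minimal standalone sketch of what `generate_split_sizes` produces; it runs on CPU with no process group, and `K=4`, `N=16` are arbitrary illustrative values:

import torch

# K - 1 random "dividers" in [0, N] partition N tokens into K non-negative splits
K, N = 4, 16
torch.manual_seed(0)
dividers = torch.randint(0, N + 1, (K - 1,))
boundaries = torch.sort(
    torch.cat([torch.tensor([0]), dividers, torch.tensor([N])])
).values
splits = boundaries[1:] - boundaries[:-1]
print(splits, splits.sum().item())  # K non-negative integers that always sum to N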

benchmarks/utils.py

Lines changed: 22 additions & 0 deletions
@@ -72,5 +72,27 @@ def profile_fwd_bwd(
    print(f"Saved: {profile_name}.json")


def profile_fn(fn, *args, profile_name="profile", **kwargs):
    # One step each of wait, warmup, and active capture
    wait, warmup, active = 1, 1, 1
    total_steps = wait + warmup + active
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=wait, warmup=warmup, active=active, repeat=0
        ),
        record_shapes=True,
    ) as prof:
        for _ in range(total_steps):
            _ = fn(*args, **kwargs)
            prof.step()

    # Save profiler results as a Chrome trace
    prof.export_chrome_trace(f"{profile_name}.json")
    print(f"Saved: {profile_name}.json")


def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
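
A minimal usage sketch of the new `profile_fn` helper, assuming the repo root is on `PYTHONPATH`; the profiled matmul is a placeholder workload:

import torch

from benchmarks.utils import profile_fn


def matmul_step(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Runs one wait, one warmup, and one active step under torch.profiler,
# then writes matmul_profile.json for chrome://tracing or Perfetto
profile_fn(matmul_step, a, b, profile_name="matmul_profile")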

test/prototype/moe_training/mxfp8/__init__.py

Whitespace-only changes.
