
Commit 02a1bc4

Merge remote-tracking branch 'origin/main' into wengshiy/int8_scaled_embedding_bag

2 parents e846df7 + 0d3217d


47 files changed: +2144, -786 lines

.github/scripts/torchao_model_releases/README.md

Lines changed: 9 additions & 2 deletions

@@ -119,7 +119,7 @@ uv pip install vllm --pre --extra-index-url https://download.pytorch.org/whl/nig
 
 After environment is setup, we can run eval:
 ```
-sh eval.sh --eval_type latency --model_ids Qwen/Qwen3-8B --batch_sizes 1,256
+sh eval.sh --eval_type latency --model_ids Qwen/Qwen3-8B --batch_sizes 1 256
 ```
 
 #### Model Quality Eval
@@ -129,9 +129,16 @@ uv pip install lm-eval
 ```
 After environment is setup, we can run eval:
 ```
-sh eval.sh --eval_type quality --model_ids Qwen/Qwen3-8B --tasks hellaswag,mmlu
+sh eval.sh --eval_type quality --model_ids Qwen/Qwen3-8B --tasks hellaswag mmlu
 ```
 
+Note: you can pass in `--use_cache` if the eval task failed during the middle of the run
+and you don't want to re-run all evals.
+```
+sh eval.sh --eval_type quality --model_ids Qwen/Qwen3-8B --tasks hellaswag mmlu --use_cache
+```
+
+
 #### Summarize results
 After we have finished all evals for each model, we can summarize the results with:
 ```

.github/scripts/torchao_model_releases/eval.sh

Lines changed: 11 additions & 2 deletions

@@ -9,7 +9,7 @@ set -e
 source eval_env_checks.sh
 
 usage() {
-  echo "Usage: $0 --model_ids <model1> <model2> ... [--eval_type <all|memory|latency|quality>] [--batch_sizes <batch_sizes>] [--tasks <tasks>]"
+  echo "Usage: $0 --model_ids <model1> <model2> ... [--eval_type <all|memory|latency|quality>] [--batch_sizes <batch_sizes>] [--tasks <tasks>] [--use_cache]"
   echo "Defaults:"
   echo " batch_sizes: 1 256"
   echo " tasks: mmlu"
@@ -20,6 +20,7 @@ EVAL_TYPE="all"
 # these will be parsed in the other scripts
 BATCH_SIZES="1 256" # Default for latency eval
 TASKS="mmlu" # Default for quality eval
+USE_CACHE=false # default: do not use cache
 # Parse arguments
 while [[ $# -gt 0 ]]; do
   case "$1" in
@@ -58,6 +59,10 @@ while [[ $# -gt 0 ]]; do
       TASKS="$1"
       shift
       ;;
+    --use_cache)
+      USE_CACHE=true
+      shift
+      ;;
     *)
       echo "Unknown argument: $1"
       usage
@@ -82,7 +87,11 @@ run_latency() {
 run_quality() {
   check_lm_eval
   local model_id="$1"
-  sh eval_quality.sh --model_ids "$model_id" --tasks $TASKS
+  if $USE_CACHE; then
+    sh eval_quality.sh --model_ids "$model_id" --tasks $TASKS --use_cache
+  else
+    sh eval_quality.sh --model_ids "$model_id" --tasks $TASKS
+  fi
 }
 for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do
   case "$EVAL_TYPE" in

.github/scripts/torchao_model_releases/eval_quality.sh

Lines changed: 15 additions & 7 deletions

@@ -11,6 +11,7 @@ check_lm_eval
 
 MODEL_ID_ARRAY=()
 TASK_ARRAY=("mmlu") # default can be overwritten by user input
+USE_CACHE=false # default: do not use cache
 # Parse arguments
 while [[ $# -gt 0 ]]; do
   case "$1" in
@@ -29,9 +30,13 @@ while [[ $# -gt 0 ]]; do
         shift
       done
       ;;
+    --use_cache)
+      USE_CACHE=true
+      shift
+      ;;
     *)
       echo "Unknown argument: $1"
-      echo "Usage: $0 --model_id <model_id> [--tasks <tasks> (comma-separated, e.g. mmlu,arc_challenge, default mmlu)]"
+      echo "Usage: $0 --model_id <model_id> [--tasks <tasks> (comma-separated, e.g. mmlu,arc_challenge, default mmlu)] [--use_cache]"
       exit 1
       ;;
   esac
@@ -51,16 +56,19 @@ for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do
     EVAL_CACHE_DB_PREFIX="/tmp/${SAFE_MODEL_ID}_quality_${TASK}"
     mkdir -p "${EVAL_CACHE_DB_PREFIX}"
     echo "Running model quality (accuracy) evaluation for model $MODEL_ID on task $TASK"
-
-    lm_eval \
+    LM_EVAL_CMD="lm_eval \
       --model hf \
-      --model_args pretrained="$MODEL_ID" \
-      --tasks "$TASK" \
+      --model_args pretrained=\"$MODEL_ID\" \
+      --tasks \"$TASK\" \
       --device cuda:0 \
-      --use_cache "$EVAL_CACHE_DB_PREFIX" \
       --batch_size auto \
-      --output_path "$RESULTS_DIR" > "$OUTPUT_FILE" 2>&1
+      --output_path \"$RESULTS_DIR\""
+
+    if $USE_CACHE; then
+      LM_EVAL_CMD="$LM_EVAL_CMD --use_cache \"$EVAL_CACHE_DB_PREFIX\""
+    fi
 
+    eval "$LM_EVAL_CMD" > "$OUTPUT_FILE" 2>&1
     echo "Quality eval output for task '$TASK' saved to $OUTPUT_FILE"
   done
   echo "======================== Eval Model Quality $MODEL_ID End =================="

benchmarks/benchmark_e2e_fp8_sparse_linear.py

Lines changed: 15 additions & 0 deletions

@@ -40,6 +40,18 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
     input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda()
     fp16_time = benchmark_microseconds(ffn_ref, input_tensor)
 
+    # Sparsify-only benchmarks
+    ao_fast_sparsification_time = benchmark_microseconds(
+        torch.ops.torchao.sparse24_sm90_sparsify(
+            input_tensor,
+            "cutlass",
+            "identity",
+            "largest",
+            dtype=torch.float8_e4m3fn,
+        )
+    )
+    cusparselt_time = benchmark_microseconds(torch._cslt_compress, input_tensor)
+
     # bf16
     ffn_clone = (
         nn.Sequential(
@@ -117,7 +129,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
         "fp8_c_time (us)": fp8_c_time,
         "fp8_c_sparse_time (us)": fp8_c_sparse_time,
         "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
+        "ao_fast_sparsification_time (us)": ao_fast_sparsification_time,
+        "cusparselt_compress_time (us)": cusparselt_time,
         "speedup": fp8_c_time / fp8_c_activation_sparse_time,
+        "sparsify_speedup": cusparselt_time / ao_fast_sparsification_time,
     }
 
 

benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 240 additions & 0 deletions (new file)

@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these benchmarks, use the following command:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
#
#######################################################################
import os
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch import distributed as dist
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
)
from tqdm import tqdm

from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)

device = torch.device("cuda")


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]


@dataclass(frozen=True)
class ExperimentResult:
    bf16_us: float
    mxfp8_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # (batch_size, seq_len, dim)
    input_shapes = [
        (8, 8192, 5120),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    batch_size, seq_len, dim = config.input_shape
    x = torch.randn(
        (batch_size * seq_len, dim),
        dtype=torch.bfloat16,
        device=device,
    )
    ref_x = x.detach().clone()

    # Max output tokens per rank is worst case where one rank receives all tokens
    input_tokens_per_rank = batch_size * seq_len
    max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()

    def using_bf16(
        input_tensor: torch.Tensor, input_splits: torch.Tensor
    ) -> torch.Tensor:
        # Calculate output splits from input splits
        output_splits = torch.empty_like(input_splits)
        dist.all_to_all_single(output_splits, input_splits)

        # Perform all-to-all
        out = all_to_all_single_autograd(
            input_tensor,
            output_splits.tolist(),
            input_splits.tolist(),
            dist.group.WORLD,
        )
        out = torch.ops._c10d_functional.wait_tensor(out)
        return out

    def using_mxfp8(
        input_tensor: torch.Tensor, input_splits: torch.Tensor
    ) -> torch.Tensor:
        output, output_splits = mxfp8_on_device_all_to_all_v(
            input_tensor,
            input_splits,
            max_output_tokens_per_rank,
            dist.group.WORLD.group_name,
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        output_splits = torch.ops._c10d_functional.wait_tensor(output_splits)
        return output

    def warmup(func_no_args):
        for _ in range(2):
            func_no_args()

    num_splits = dist.get_world_size()
    input_splits = generate_split_sizes(
        num_splits, input_tokens_per_rank, device=device
    )

    print(
        "Benchmarking using bf16",
        "batch_size",
        batch_size,
        "seq_len",
        seq_len,
        "dim",
        dim,
        "input_tokens_per_rank",
        input_tokens_per_rank,
        "max_output_tokens_per_rank",
        max_output_tokens_per_rank,
    )
    warmup(lambda: using_bf16(ref_x, input_splits))
    start_ns = time.perf_counter()
    using_bf16(ref_x, input_splits)
    end_ns = time.perf_counter()
    bf16_us = (end_ns - start_ns) * 1e6

    print(
        "Benchmarking using_mxfp8",
        "batch_size",
        batch_size,
        "seq_len",
        seq_len,
        "dim",
        dim,
        "input_tokens_per_rank",
        input_tokens_per_rank,
        "max_output_tokens_per_rank",
        max_output_tokens_per_rank,
    )
    warmup(lambda: using_mxfp8(x, input_splits))
    start_ns = time.perf_counter()
    using_mxfp8(x, input_splits)
    end_ns = time.perf_counter()
    mxfp8_us = (end_ns - start_ns) * 1e6

    return ExperimentResult(
        bf16_us=bf16_us,
        mxfp8_us=mxfp8_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "num_splits",
        "bf16_us",
        "mxfp8_us",
    ]
    rows = []
    num_splits = dist.get_world_size()
    for experiment in experiments:
        rows.append(
            [
                str(experiment.config.input_shape),
                num_splits,
                experiment.result.bf16_us,
                experiment.result.mxfp8_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
    """
    Generates a tensor of K random non-negative integers that sum to N.
    Used for testing mxfp8_all_to_all_v implementation.
    """
    if K <= 0:
        raise ValueError("K must be a positive integer.")
    if N < 0:
        raise ValueError("N must be a non-negative integer.")

    if K == 1:
        return torch.tensor([N], dtype=torch.long, device=device)

    # Generate K-1 random "dividers" in the range [0, N].
    dividers = torch.randint(0, N + 1, (K - 1,), device=device)

    # Add 0 and N to the set of dividers to form the boundaries.
    boundaries = torch.cat(
        [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)]
    )

    # Sort the boundaries to ensure they are in order
    sorted_boundaries = torch.sort(boundaries).values

    # The K integers are the differences between consecutive boundaries (will sum to N)
    result = sorted_boundaries[1:] - sorted_boundaries[:-1]

    return result.to(dtype=torch.int64)


def main():
    torch.random.manual_seed(123)

    # Set up process group
    setup_distributed()

    # Generate experiment configs
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)

    # Clean up process group
    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


if __name__ == "__main__":
    main()
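As the header comment notes, the benchmark is launched with torchrun; a launch sketch, assuming 8 local GPUs (each rank pins one device via `torch.cuda.set_device(rank)`):

```
# torchrun sets the RANK and WORLD_SIZE environment variables that
# setup_distributed() reads before init_process_group("nccl", ...).
# --local-ranks-filter=0 limits console output to rank 0.
torchrun --nproc-per-node=8 --local-ranks-filter=0 \
    benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
```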
