|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +###################################################################### |
| 7 | +# |
| 8 | +# To run these benchmarks, use the following command: |
| 9 | +# |
| 10 | +# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py |
| 11 | +# |
| 12 | +####################################################################### |
| 13 | +import os |
| 14 | +import time |
| 15 | +from dataclasses import dataclass |
| 16 | +from typing import List |
| 17 | + |
| 18 | +import torch |
| 19 | +from tabulate import tabulate |
| 20 | +from torch import distributed as dist |
| 21 | +from torch.distributed._functional_collectives import ( |
| 22 | + all_to_all_single_autograd, |
| 23 | +) |
| 24 | +from tqdm import tqdm |
| 25 | + |
| 26 | +from torchao.prototype.moe_training.kernels.mxfp8.comms import ( |
| 27 | + mxfp8_on_device_all_to_all_v, |
| 28 | +) |
| 29 | + |
| 30 | +device = torch.device("cuda") |
| 31 | + |
| 32 | + |
| 33 | +@dataclass(frozen=True) |
| 34 | +class ExperimentConfig: |
| 35 | + input_shape: tuple[int] |
| 36 | + |
| 37 | + |
| 38 | +@dataclass(frozen=True) |
| 39 | +class ExperimentResult: |
| 40 | + bf16_us: float |
| 41 | + mxfp8_us: float |
| 42 | + |
| 43 | + |
| 44 | +@dataclass(frozen=True) |
| 45 | +class Experiment: |
| 46 | + config: ExperimentConfig |
| 47 | + result: ExperimentResult |
| 48 | + |
| 49 | + |
| 50 | +def get_configs() -> List[ExperimentConfig]: |
| 51 | + # (batch_size, seq_len, dim) |
| 52 | + input_shapes = [ |
| 53 | + (8, 8192, 5120), |
| 54 | + ] |
| 55 | + configs = [] |
| 56 | + for shape in input_shapes: |
| 57 | + configs.append( |
| 58 | + ExperimentConfig( |
| 59 | + input_shape=shape, |
| 60 | + ) |
| 61 | + ) |
| 62 | + return configs |
| 63 | + |
| 64 | + |
| 65 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 66 | + batch_size, seq_len, dim = config.input_shape |
| 67 | + x = torch.randn( |
| 68 | + (batch_size * seq_len, dim), |
| 69 | + dtype=torch.bfloat16, |
| 70 | + device=device, |
| 71 | + ) |
| 72 | + ref_x = x.detach().clone() |
| 73 | + |
| 74 | + # Max output tokens per rank is worst case where one rank receives all tokens |
| 75 | + input_tokens_per_rank = batch_size * seq_len |
| 76 | + max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size() |
| 77 | + |
| 78 | + def using_bf16( |
| 79 | + input_tensor: torch.Tensor, input_splits: torch.Tensor |
| 80 | + ) -> torch.Tensor: |
| 81 | + # Calculate output splits from input splits |
| 82 | + output_splits = torch.empty_like(input_splits) |
| 83 | + dist.all_to_all_single(output_splits, input_splits) |
| 84 | + |
| 85 | + # Perform all-to-all |
| 86 | + out = all_to_all_single_autograd( |
| 87 | + input_tensor, |
| 88 | + output_splits.tolist(), |
| 89 | + input_splits.tolist(), |
| 90 | + dist.group.WORLD, |
| 91 | + ) |
| 92 | + out = torch.ops._c10d_functional.wait_tensor(out) |
| 93 | + return out |
| 94 | + |
| 95 | + def using_mxfp8( |
| 96 | + input_tensor: torch.Tensor, input_splits: torch.Tensor |
| 97 | + ) -> torch.Tensor: |
| 98 | + output, output_splits = mxfp8_on_device_all_to_all_v( |
| 99 | + input_tensor, |
| 100 | + input_splits, |
| 101 | + max_output_tokens_per_rank, |
| 102 | + dist.group.WORLD.group_name, |
| 103 | + ) |
| 104 | + output = torch.ops._c10d_functional.wait_tensor(output) |
| 105 | + output_splits = torch.ops._c10d_functional.wait_tensor(output_splits) |
| 106 | + return output |
| 107 | + |
| 108 | + def warmup(func_no_args): |
| 109 | + for _ in range(2): |
| 110 | + func_no_args() |
| 111 | + |
| 112 | + num_splits = dist.get_world_size() |
| 113 | + input_splits = generate_split_sizes( |
| 114 | + num_splits, input_tokens_per_rank, device=device |
| 115 | + ) |
| 116 | + |
| 117 | + print( |
| 118 | + "Benchmarking using bf16", |
| 119 | + "batch_size", |
| 120 | + batch_size, |
| 121 | + "seq_len", |
| 122 | + seq_len, |
| 123 | + "dim", |
| 124 | + dim, |
| 125 | + "input_tokens_per_rank", |
| 126 | + input_tokens_per_rank, |
| 127 | + "max_output_tokens_per_rank", |
| 128 | + max_output_tokens_per_rank, |
| 129 | + ) |
| 130 | + warmup(lambda: using_bf16(ref_x, input_splits)) |
| 131 | + start_ns = time.perf_counter() |
| 132 | + using_bf16(ref_x, input_splits) |
| 133 | + end_ns = time.perf_counter() |
| 134 | + bf16_us = (end_ns - start_ns) * 1e6 |
| 135 | + |
| 136 | + print( |
| 137 | + "Benchmarking using_mxfp8", |
| 138 | + "batch_size", |
| 139 | + batch_size, |
| 140 | + "seq_len", |
| 141 | + seq_len, |
| 142 | + "dim", |
| 143 | + dim, |
| 144 | + "input_tokens_per_rank", |
| 145 | + input_tokens_per_rank, |
| 146 | + "max_output_tokens_per_rank", |
| 147 | + max_output_tokens_per_rank, |
| 148 | + ) |
| 149 | + warmup(lambda: using_mxfp8(x, input_splits)) |
| 150 | + start_ns = time.perf_counter() |
| 151 | + using_mxfp8(x, input_splits) |
| 152 | + end_ns = time.perf_counter() |
| 153 | + mxfp8_us = (end_ns - start_ns) * 1e6 |
| 154 | + |
| 155 | + return ExperimentResult( |
| 156 | + bf16_us=bf16_us, |
| 157 | + mxfp8_us=mxfp8_us, |
| 158 | + ) |
| 159 | + |
| 160 | + |
| 161 | +def print_results(experiments: List[Experiment]): |
| 162 | + headers = [ |
| 163 | + "input_shape", |
| 164 | + "num_splits", |
| 165 | + "bf16_us", |
| 166 | + "mxfp8_us", |
| 167 | + ] |
| 168 | + rows = [] |
| 169 | + num_splits = dist.get_world_size() |
| 170 | + for experiment in experiments: |
| 171 | + rows.append( |
| 172 | + [ |
| 173 | + str(experiment.config.input_shape), |
| 174 | + num_splits, |
| 175 | + experiment.result.bf16_us, |
| 176 | + experiment.result.mxfp8_us, |
| 177 | + ] |
| 178 | + ) |
| 179 | + print(tabulate(rows, headers=headers)) |
| 180 | + |
| 181 | + |
| 182 | +def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor: |
| 183 | + """ |
| 184 | + Generates a tensor of K random non-negative integers that sum to N. |
| 185 | + Used for testing mxfp8_all_to_all_v implementation. |
| 186 | + """ |
| 187 | + if K <= 0: |
| 188 | + raise ValueError("K must be a positive integer.") |
| 189 | + if N < 0: |
| 190 | + raise ValueError("N must be a non-negative integer.") |
| 191 | + |
| 192 | + if K == 1: |
| 193 | + return torch.tensor([N], dtype=torch.long, device=device) |
| 194 | + |
| 195 | + # Generate K-1 random "dividers" in the range [0, N]. |
| 196 | + dividers = torch.randint(0, N + 1, (K - 1,), device=device) |
| 197 | + |
| 198 | + # Add 0 and N to the set of dividers to form the boundaries. |
| 199 | + boundaries = torch.cat( |
| 200 | + [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)] |
| 201 | + ) |
| 202 | + |
| 203 | + # Sort the boundaries to ensure they are in order |
| 204 | + sorted_boundaries = torch.sort(boundaries).values |
| 205 | + |
| 206 | + # The K integers are the differences between consecutive boundaries (will sum to N) |
| 207 | + result = sorted_boundaries[1:] - sorted_boundaries[:-1] |
| 208 | + |
| 209 | + return result.to(dtype=torch.int64) |
| 210 | + |
| 211 | + |
| 212 | +def main(): |
| 213 | + torch.random.manual_seed(123) |
| 214 | + |
| 215 | + # Set up process group |
| 216 | + setup_distributed() |
| 217 | + |
| 218 | + # Generate experiment configs |
| 219 | + configs = get_configs() |
| 220 | + results = [] |
| 221 | + for config in tqdm(configs): |
| 222 | + result = run_experiment(config) |
| 223 | + results.append(Experiment(config=config, result=result)) |
| 224 | + |
| 225 | + # Use Tabulate to print results |
| 226 | + print_results(results) |
| 227 | + |
| 228 | + # Clean up process group |
| 229 | + dist.destroy_process_group() |
| 230 | + |
| 231 | + |
| 232 | +def setup_distributed(): |
| 233 | + rank = int(os.environ["RANK"]) |
| 234 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 235 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 236 | + torch.cuda.set_device(rank) |
| 237 | + |
| 238 | + |
| 239 | +if __name__ == "__main__": |
| 240 | + main() |
0 commit comments