 # torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
+import argparse
 import os
 import time
 from dataclasses import dataclass
 import torch
 from tabulate import tabulate
 from torch import distributed as dist
+from torch.distributed import DeviceMesh, init_device_mesh
 from torch.distributed._functional_collectives import (
+    all_to_all_single,
     all_to_all_single_autograd,
 )
 from tqdm import tqdm

+from benchmarks.utils import profile_fn
 from torchao.prototype.moe_training.kernels.mxfp8.comms import (
     mxfp8_on_device_all_to_all_v,
 )
@@ -37,8 +41,8 @@ class ExperimentConfig:

 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_us: float
-    mxfp8_us: float
+    bf16_ms: float
+    mxfp8_ms: float


 @dataclass(frozen=True)
@@ -50,7 +54,7 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # (batch_size, seq_len, dim)
     input_shapes = [
-        (8, 8192, 5120),
+        (16, 8192, 5120),
     ]
     configs = []
     for shape in input_shapes:
@@ -62,7 +66,111 @@ def get_configs() -> List[ExperimentConfig]:
     return configs


-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
+def default_a2a_dispatch(
+    routed_input: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor,
+    device_mesh: DeviceMesh,
+):
+    """
+    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
+
+    Returns:
+        routed_input: the local tokens after all-to-all dispatch
+        input_splits: the input splits for all-to-all dispatch
+        output_splits: the output splits for all-to-all dispatch
+        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
+    """
+    ep_degree = device_mesh.size(0)
+    # generate the input splits and output splits for all-to-all
+    with torch.no_grad():
+        num_tokens_per_expert_group = all_to_all_single(
+            num_tokens_per_expert,
+            None,
+            None,
+            group=device_mesh.get_group(),
+        )
+        # Need to wait explicitly because it is used by a triton kernel later
+        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+            num_tokens_per_expert_group
+        )
+        input_splits = (
+            num_tokens_per_expert.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=True)
+        )
+        # NOTE: this would incur a device-to-host sync
+        output_splits = (
+            num_tokens_per_expert_group.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=False)
+        )
+        input_splits_list = input_splits.tolist()
+        output_splits_list = output_splits.tolist()
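+        # all_to_all_single_autograd takes the split sizes as Python lists,
+        # which is why both split tensors are copied to CPU above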
+
+    # perform all-to-all
+    routed_input = all_to_all_single_autograd(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+    routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+    return (
+        routed_input,
+        input_splits_list,
+        output_splits_list,
+        num_tokens_per_expert_group,
+    )
+
+
+def mxfp8_a2a_dispatch(
+    routed_input: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor,
+    device_mesh: DeviceMesh,
+    max_tokens_per_ep_rank: int,
+):
+    """
+    Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
+    and avoid device-to-host sync.
+
+    Returns:
+        routed_input: the local tokens after all-to-all dispatch
+        input_splits: the input splits for all-to-all dispatch
+        output_splits: the output splits for all-to-all dispatch
+        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
+    """
+
+    ep_degree = device_mesh.size(0)
+    num_tokens_per_expert_group = all_to_all_single(
+        num_tokens_per_expert,
+        None,
+        None,
+        group=device_mesh.get_group(),
+    )
+    input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
+    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+        num_tokens_per_expert_group
+    )
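+    # Input split sizes stay on device; the mxfp8 kernel consumes the tensor
+    # directly, so no device-to-host sync is needed here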
+    routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
+        routed_input,
+        input_splits_per_ep_rank,
+        max_tokens_per_ep_rank,
+        device_mesh.get_group().group_name,
+    )
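+    # The output buffer is sized for the worst case (max_tokens_per_ep_rank);
+    # trim it to the number of tokens actually received on this rank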
+    tokens_on_rank_after_a2a = output_splits_per_ep_rank.sum()
+    routed_input_no_padding = routed_input[:tokens_on_rank_after_a2a]
+    return (
+        routed_input_no_padding,
+        input_splits_per_ep_rank,
+        output_splits_per_ep_rank,
+        num_tokens_per_expert_group,
+    )
+
+
+def run_experiment(
+    config: ExperimentConfig, args: argparse.Namespace
+) -> ExperimentResult:
     batch_size, seq_len, dim = config.input_shape
     x = torch.randn(
         (batch_size * seq_len, dim),
@@ -71,99 +179,70 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     )
     ref_x = x.detach().clone()

+    # Set up device mesh
+    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
+
     # Max output tokens per rank is worst case where one rank receives all tokens
     input_tokens_per_rank = batch_size * seq_len
     max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()

-    def using_bf16(
-        input_tensor: torch.Tensor, input_splits: torch.Tensor
-    ) -> torch.Tensor:
-        # Calculate output splits from input splits
-        output_splits = torch.empty_like(input_splits)
-        dist.all_to_all_single(output_splits, input_splits)
-
-        # Perform all-to-all
-        out = all_to_all_single_autograd(
-            input_tensor,
-            output_splits.tolist(),
-            input_splits.tolist(),
-            dist.group.WORLD,
-        )
-        out = torch.ops._c10d_functional.wait_tensor(out)
-        return out
-
-    def using_mxfp8(
-        input_tensor: torch.Tensor, input_splits: torch.Tensor
-    ) -> torch.Tensor:
-        output, output_splits = mxfp8_on_device_all_to_all_v(
-            input_tensor,
-            input_splits,
-            max_output_tokens_per_rank,
-            dist.group.WORLD.group_name,
-        )
-        output = torch.ops._c10d_functional.wait_tensor(output)
-        output_splits = torch.ops._c10d_functional.wait_tensor(output_splits)
-        return output
-
     def warmup(func_no_args):
         for _ in range(2):
             func_no_args()

-    num_splits = dist.get_world_size()
+    num_experts_per_rank = 2
+    num_splits = dist.get_world_size() * num_experts_per_rank
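+    # One split per (EP rank, local expert) pair, mirroring a per-expert
+    # num_tokens_per_expert layout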
     input_splits = generate_split_sizes(
         num_splits, input_tokens_per_rank, device=device
     )

-    print(
-        "Benchmarking using bf16",
-        "batch_size",
-        batch_size,
-        "seq_len",
-        seq_len,
-        "dim",
-        dim,
-        "input_tokens_per_rank",
-        input_tokens_per_rank,
-        "max_output_tokens_per_rank",
-        max_output_tokens_per_rank,
-    )
-    warmup(lambda: using_bf16(ref_x, input_splits))
-    start_ns = time.perf_counter()
-    using_bf16(ref_x, input_splits)
-    end_ns = time.perf_counter()
-    bf16_us = (end_ns - start_ns) * 1e6
-
-    print(
-        "Benchmarking using_mxfp8",
-        "batch_size",
-        batch_size,
-        "seq_len",
-        seq_len,
-        "dim",
-        dim,
-        "input_tokens_per_rank",
-        input_tokens_per_rank,
-        "max_output_tokens_per_rank",
-        max_output_tokens_per_rank,
+    # Bench default a2a
+    warmup(lambda: default_a2a_dispatch(ref_x, input_splits, mesh))
+    start_sec = time.perf_counter()
+    default_a2a_dispatch(ref_x, input_splits, mesh)
+    end_sec = time.perf_counter()
+    bf16_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            default_a2a_dispatch,
+            ref_x,
+            input_splits,
+            mesh,
+            distributed=True,
+            profile_name="all_to_all_single_autograd",
+        )
+
+    # Bench mxfp8 a2a
+    warmup(
+        lambda: mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
     )
-    warmup(lambda: using_mxfp8(x, input_splits))
-    start_ns = time.perf_counter()
-    using_mxfp8(x, input_splits)
-    end_ns = time.perf_counter()
-    mxfp8_us = (end_ns - start_ns) * 1e6
+    start_sec = time.perf_counter()
+    mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
+    end_sec = time.perf_counter()
+    mxfp8_ms = (end_sec - start_sec) * 1e3
+    if args.profile:
+        profile_fn(
+            mxfp8_a2a_dispatch,
+            x,
+            input_splits,
+            mesh,
+            max_output_tokens_per_rank,
+            distributed=True,
+            profile_name="mxfp8_all_to_all_v",
+        )

     return ExperimentResult(
-        bf16_us=bf16_us,
-        mxfp8_us=mxfp8_us,
+        bf16_ms=bf16_ms,
+        mxfp8_ms=mxfp8_ms,
     )


 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_us",
-        "mxfp8_us",
+        "bf16_ms",
+        "mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -172,8 +251,8 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_us,
-                experiment.result.mxfp8_us,
+                experiment.result.bf16_ms,
+                experiment.result.mxfp8_ms,
             ]
         )
     print(tabulate(rows, headers=headers))
@@ -209,7 +288,7 @@ def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
     return result.to(dtype=torch.int64)


-def main():
+def main(args: argparse.Namespace):
     torch.random.manual_seed(123)

     # Set up process group
@@ -219,7 +298,7 @@ def main():
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, args)
         results.append(Experiment(config=config, result=result))

     # Use Tabulate to print results
@@ -237,4 +316,7 @@ def setup_distributed():


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--profile", action="store_true")
+    args = parser.parse_args()
+    main(args)