see README.md for more details
"""

+ from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@@ -54,6 +55,8 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
    num_profiles: int = 2
    num_mul: int = 5
    num_concat: int = 100
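+     # forwarded via func_kwargs to the multi_stream_* benchmarks below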
+     multi_stream: bool = True
+     main_stream_allocation: bool = False


def _compute(
@@ -94,6 +97,7 @@ def a2a_sync_base(
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
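+     # catch-all for config fields this benchmark does not use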
+     **_kwargs: Dict[str, Any],
) -> None:
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +190,7 @@ def a2a_async_twice(
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
+     **_kwargs: Dict[str, Any],
) -> None:
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +259,14 @@ def a2a_async_twice(
        assert checks1 and checks2


- # all_to_all_single with sync and single stream
+ # LazyAwaitable
def lazyawaitable(
    _batch_inputs: List[Dict[str, Any]],
    dim: int,
    num_mul: int,
    num_concat: int,
    ctx: MultiProcessContext,
+     **_kwargs: Dict[str, Any],
) -> None:
    with record_function("## pre-comms compute ##"):
        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -294,6 +300,149 @@ def lazyawaitable(
        assert check_awaitable.item()


+ # multi-stream memory footprint
+ def multi_stream_memory(
+     _batch_inputs: List[Dict[str, Any]],
+     dim: int,
+     num_mul: int,
+     num_concat: int,
+     ctx: MultiProcessContext,
+     multi_stream: bool,
+     **_kwargs: Dict[str, Any],
+ ) -> None:
+     with record_function("## setup ##"):
+         main_stream = torch.cuda.current_stream()
+         data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+         data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
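+         # nullcontext() keeps everything on the current stream when multi_stream=False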
+         irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+ 
+         # without pin_memory(), the non_blocking host-to-device copy degrades to a blocking transfer
+         host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+ 
+     with record_function("## irrelevant compute before h2d ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## copy data to device ##"):
+         with data_copy_stream:
+             device_data = host_data.to(ctx.device, non_blocking=True)
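+             # .to() allocates device_data from data_copy_stream's caching-allocator pool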
+ 
+     with record_function("## irrelevant compute after h2d ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## pre-comms compute ##"):
+         if isinstance(data_copy_stream, torch.cuda.Stream):
+             main_stream.wait_stream(data_copy_stream)
+             device_data.record_stream(main_stream)
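+             # record_stream prevents the allocator from reusing device_data before main_stream finishes with it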
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+         )
+ 
+     with data_dist_stream:
+         with record_function("## all_to_all_single ##"):
+             if isinstance(data_dist_stream, torch.cuda.Stream):
+                 data_dist_stream.wait_stream(main_stream)
+             post_comms = torch.zeros_like(pre_comms)
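+             # allocated while data_dist_stream is current; this side-stream pool is the extra footprint this benchmark exposes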
+             req = dist.all_to_all_single(
+                 output=post_comms,
+                 input=pre_comms,
+                 group=ctx.pg,
+                 async_op=True,
+             )
+         with record_function("## a2a comm validation ##"):
+             req.wait()
+             checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+ 
+     with record_function("## irrelevant compute after a2a ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## post-comms compute ##"):
+         req.wait()
+         post_comms.record_stream(main_stream)
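+         # mark post_comms as used on main_stream so its memory is not reclaimed too early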
+         post_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+         )
+ 
+     with record_function("## assert ##"):
+         assert checks.item()
+ 
+ 
+ def multi_stream_optimized(
+     _batch_inputs: List[Dict[str, Any]],
+     dim: int,
+     num_mul: int,
+     num_concat: int,
+     ctx: MultiProcessContext,
+     **_kwargs: Dict[str, Any],
+ ) -> None:
+     with record_function("## setup ##"):
+         main_stream = torch.cuda.current_stream()
+         data_copy_stream = torch.cuda.Stream()
+         data_dist_stream = torch.cuda.Stream()
+         irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+ 
+         # without pin_memory(), the non_blocking copy below degrades to a blocking transfer
+         host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+         device_data = torch.empty_like(host_data, device=ctx.device)
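+         # pre-allocated on the main stream so the copy below does not allocate on a side stream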
+ 
+     with record_function("## irrelevant compute before h2d ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## copy data to device ##"):
+         with data_copy_stream:
+             device_data.record_stream(data_copy_stream)
+             device_data.copy_(host_data, non_blocking=True)
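+             # in-place copy into the main-stream buffer; nothing is allocated on data_copy_stream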
+ 
+     with record_function("## irrelevant compute after h2d ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## pre-comms compute ##"):
+         main_stream.wait_stream(data_copy_stream)
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+         )
+ 
+     with record_function("## pre-allocate memory for a2a on main stream ##"):
+         post_comms = torch.zeros_like(pre_comms)
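+         # unlike multi_stream_memory, the a2a output comes from the main stream's pool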
+ 
+     with data_dist_stream:
+         with record_function("## all_to_all_single ##"):
+             data_dist_stream.wait_stream(main_stream)
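+             # ensure the collective sees main_stream's writes to pre_comms and post_comms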
+             req = dist.all_to_all_single(
+                 output=post_comms,
+                 input=pre_comms,
+                 group=ctx.pg,
+                 async_op=True,
+             )
+ 
+     with record_function("## irrelevant compute after a2a ##"):
+         pre_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+         )
+ 
+     with record_function("## a2a comm validation ##"):
+         req.wait()
+         checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+ 
+     with record_function("## post-comms compute ##"):
+         post_comms = _compute(
+             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+         )
+ 
+     with record_function("## assert ##"):
+         assert checks.item()
+ 
+ 
# single-rank runner
def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
    # Ensure GPUs are available and we have enough of them
@@ -317,6 +466,10 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
            func = a2a_async_twice
        elif arg.name.startswith("lazyawaitable"):
            func = lazyawaitable
+         elif arg.name.startswith("multi_stream_memory"):
+             func = multi_stream_memory
+         elif arg.name.startswith("multi_stream_optimized"):
+             func = multi_stream_optimized
        else:
            raise ValueError(f"Unknown benchmark name: {arg.name}")

@@ -328,6 +481,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
                "dim": arg.dim,
                "num_mul": arg.num_mul,
                "num_concat": arg.num_concat,
+                 "multi_stream": arg.multi_stream,
+                 "main_stream_allocation": arg.main_stream_allocation,
            },
            func_to_benchmark=func,
            rank=rank,