22import json
33import logging
44import os
5+ import re
56import pathlib
67from collections import defaultdict
78from typing import Any , Callable , Dict
@@ -138,8 +139,28 @@ def _get_event_busbw_factor(evt):
138139
139140 return correction_factor_func (group_size )
140141
141-
142- def calculate_bw_ (trace_data ):
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
    """Compute the bus bandwidth of an all_to_all event with explicit splits.

    The larger of the element counts exchanged with the other ranks (sent
    vs. received) is converted to a data size and divided by the event
    duration.

    Args:
        evt: Trace event dict. Reads ``evt["dur"]`` and, under
            ``evt["args"]``: "Group size", "Process Group Ranks",
            "In msg nelems", "Out msg nelems", "In split size",
            "Out split size" (both parsed with ``ast.literal_eval``) and
            "dtype".
        global_rank: Rank to look up in the event's process group ranks;
            its position in that list selects the split entry to exclude.

    Returns:
        The bandwidth rounded to two decimals. The caller stores it under
        "busbw (GB/sec)"; this presumably relies on ``dur`` being in
        microseconds and ``_dtype_size_map`` giving bytes per element --
        TODO confirm.

    Raises:
        ValueError: If ``global_rank`` is not among the parsed group ranks.
    """
    args = evt["args"]
    group_size = args["Group size"]
    group_ranks = _parse_ranks(args["Process Group Ranks"], group_size)
    # Position of this rank inside its process group.
    local_rank = group_ranks.index(global_rank)

    in_total = args["In msg nelems"]
    out_total = args["Out msg nelems"]
    # Split sizes are stored as strings holding a Python list literal.
    in_splits = ast.literal_eval(args["In split size"])
    out_splits = ast.literal_eval(args["Out split size"])
    elem_size = _dtype_size_map[args["dtype"]]

    def _elems_for_other_ranks(total, splits):
        # With explicit splits, drop the entry at this rank's own position;
        # with an empty split list, assume an even split across the group.
        if splits:
            return total - splits[local_rank]
        return total / group_size * (group_size - 1)

    send_elems = _elems_for_other_ranks(in_total, in_splits)
    recv_elems = _elems_for_other_ranks(out_total, out_splits)

    busiest_direction = max(send_elems, recv_elems)
    return round(busiest_direction * elem_size / evt["dur"] * 1e-3, 2)
162+
163+ def calculate_bw_ (trace_data , global_rank ):
143164 nccl_events = [
144165 i
145166 for i in trace_data ["traceEvents" ]
@@ -163,7 +184,14 @@ def calculate_bw_(trace_data):
163184
164185 algbw = _calculate_algbw (evt )
165186 busbw_factor = _get_event_busbw_factor (evt )
166- busbw = round (algbw * busbw_factor , 2 )
187+ if (coll_name in ["all_to_all" , "all_to_allv" ]
188+ and (ast .literal_eval (evt ['args' ]['In split size' ])
189+ or ast .literal_eval (evt ['args' ]['Out split size' ]))
190+ ):
191+ # calculate busbw for uneven all_to_all
192+ busbw = _calculate_busbw_for_uneven_all_to_all (evt , global_rank )
193+ else :
194+ busbw = round (algbw * busbw_factor , 2 )
167195
168196 evt ["args" ]["algbw (GB/sec)" ] = algbw
169197 evt ["args" ]["busbw (GB/sec)" ] = busbw
@@ -282,18 +310,19 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
282310 # list of shared bw
283311 sbw_lst = []
284312
285- # key is (kernel_name, data size, ranks number )
313+ # key is (kernel_name, coll name, data size, ranks count )
286314 # value is list of [dur, algbw, busbw, pg]
287315 comm_bw_data = defaultdict (list )
288316
289317 for fpath in os .scandir (trace_dir ):
290318 if not fpath .is_file ():
291319 continue
292-
320+
321+ global_rank = int (re .search (r"rank-(\d+)" , fpath .name ).group (1 ))
293322 with open (fpath .path , "r" , encoding = "utf-8" ) as f :
294323 trace = json .load (f )
295324
296- calculate_bw_ (trace )
325+ calculate_bw_ (trace , global_rank )
297326 with open (
298327 os .path .join (processed_trace_dir , fpath .name ), "w" , encoding = "utf-8"
299328 ) as f :
0 commit comments