Skip to content

Commit ac529da

Browse files
committed
fix busbw calculation of uneven all_to_all
1 parent 263ac16 commit ac529da

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

et_replay/comm/profiler_trace_analysis.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import os
5+
import re
56
import pathlib
67
from collections import defaultdict
78
from typing import Any, Callable, Dict
@@ -138,8 +139,28 @@ def _get_event_busbw_factor(evt):
138139

139140
return correction_factor_func(group_size)
140141

141-
142-
def calculate_bw_(trace_data):
142+
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
143+
group_size = evt["args"]["Group size"]
144+
local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(global_rank)
145+
in_elems_count = evt["args"]["In msg nelems"]
146+
out_elems_count = evt["args"]["Out msg nelems"]
147+
in_split_size = ast.literal_eval(evt["args"]["In split size"])
148+
out_split_size = ast.literal_eval(evt["args"]["Out split size"])
149+
dtype_size = _dtype_size_map[evt["args"]["dtype"]]
150+
151+
if in_split_size:
152+
send_elems = in_elems_count - in_split_size[local_rank]
153+
else:
154+
send_elems = in_elems_count / group_size * (group_size - 1)
155+
156+
if out_split_size:
157+
recv_elems = out_elems_count - out_split_size[local_rank]
158+
else:
159+
recv_elems = out_elems_count / group_size * (group_size - 1)
160+
161+
return round(max(send_elems, recv_elems) * dtype_size / evt["dur"] * 1e-3, 2)
162+
163+
def calculate_bw_(trace_data, global_rank):
143164
nccl_events = [
144165
i
145166
for i in trace_data["traceEvents"]
@@ -163,7 +184,14 @@ def calculate_bw_(trace_data):
163184

164185
algbw = _calculate_algbw(evt)
165186
busbw_factor = _get_event_busbw_factor(evt)
166-
busbw = round(algbw * busbw_factor, 2)
187+
if (coll_name in ["all_to_all", "all_to_allv"]
188+
and (ast.literal_eval(evt['args']['In split size'])
189+
or ast.literal_eval(evt['args']['Out split size']))
190+
):
191+
# calculate busbw for uneven all_to_all
192+
busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
193+
else:
194+
busbw = round(algbw * busbw_factor, 2)
167195

168196
evt["args"]["algbw (GB/sec)"] = algbw
169197
evt["args"]["busbw (GB/sec)"] = busbw
@@ -282,18 +310,19 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
282310
# list of shared bw
283311
sbw_lst = []
284312

285-
# key is (kernel_name, data size, ranks number)
313+
# key is (kernel_name, coll name, data size, ranks count)
286314
# value is list of [dur, algbw, busbw, pg]
287315
comm_bw_data = defaultdict(list)
288316

289317
for fpath in os.scandir(trace_dir):
290318
if not fpath.is_file():
291319
continue
292-
320+
321+
global_rank = int(re.search(r"rank-(\d+)", fpath.name).group(1))
293322
with open(fpath.path, "r", encoding="utf-8") as f:
294323
trace = json.load(f)
295324

296-
calculate_bw_(trace)
325+
calculate_bw_(trace, global_rank)
297326
with open(
298327
os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
299328
) as f:

0 commit comments

Comments
 (0)