-
Notifications
You must be signed in to change notification settings - Fork 188
/
profiling_utils.py
146 lines (123 loc) · 5.7 KB
/
profiling_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import os
import time
import logging
import torch
import torch.distributed
from functools import partial
import shutil
from torch.profiler import tensorboard_trace_handler
WARMUP = 3
logger = logging.getLogger()
#adapted from https://github.com/pytorch/torchtitan
def trace_handler(prof: torch.profiler.profiler.profile, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25):
    """Export profiler artifacts for the current profiling cycle.

    Used as the ``on_trace_ready`` callback of ``torch.profiler.profile``;
    writes all outputs into ``<output_dir>/iteration_<step>``: a
    tensorboard/chrome trace, optionally a memory timeline, aggregated stack
    traces, and a key-averages table.

    Args:
        prof: The active profiler instance (supplies ``step_num`` and the
            various ``export_*`` methods).
        rank: Distributed rank of this process; embedded in output filenames.
        export_memory_timeline: If truthy, also dump a memory-timeline artifact.
        output_dir: Root directory under which per-iteration folders are created.
        metric: Sort key for the stack export and the key-averages table.
        with_stack: If True, export aggregated stack traces.
        group_by_stack: Passed as ``group_by_stack_n`` to ``key_averages``.
        group_by_input_shape: Whether to group event averages by input shape.
        row_limit: Maximum rows in the key-averages table.
    """
    curr_trace_dir_name = "iteration_" + str(prof.step_num)
    curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name)
    # exist_ok=True already tolerates a pre-existing directory; no exists() check needed.
    os.makedirs(curr_trace_dir, exist_ok=True)

    # Export chrome / tensorboard trace.
    logger.info(f"Dumping traces at step {prof.step_num}")
    begin = time.monotonic()
    # Use tensorboard_trace_handler rather than prof.export_chrome_trace directly:
    # tensorboard doesn't seem to be able to parse traces written by export_chrome_trace.
    exporter = tensorboard_trace_handler(curr_trace_dir, worker_name=f"rank{rank}", use_gzip=True)
    exporter(prof)
    logger.info(
        f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds"
    )

    # Construct the memory timeline file, degrading through formats on failure:
    # html -> gzipped json -> raw MemoryProfileTimeline object.
    if export_memory_timeline:
        try:
            prof.export_memory_timeline(
                f"{curr_trace_dir}/rank{rank}_memory-timeline.html"
            )
        # except Exception (not bare except) so Ctrl-C / SystemExit still propagate;
        # logger.exception preserves the traceback for debugging the export failure.
        except Exception:
            logger.exception("Failed to export memory timeline to html, retrying as gzipped json.")
            try:
                prof.export_memory_timeline(
                    f"{curr_trace_dir}/rank{rank}_memory-timeline.json.gz"
                )
            except Exception:
                logger.exception("Failed to export memory timeline to gzipped json. Saving profiler timeline object instead.")
                from torch.profiler._memory_profiler import MemoryProfileTimeline
                memory_profile = MemoryProfileTimeline(prof._memory_profile())
                torch.save(memory_profile, f"{curr_trace_dir}/rank{rank}_memory-timeline.pt")

    # Dump aggregated stack traces.
    if with_stack:
        prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric)

    # Export event averages as a human-readable table.
    key_avgs = prof.key_averages(
        group_by_input_shape=group_by_input_shape, group_by_stack_n=group_by_stack
    ).table(sort_by=metric, row_limit=row_limit)
    with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f:
        print(
            key_avgs, file=f
        )
    if rank == 0:
        print(f"Saving profiling results to {curr_trace_dir}")

    # TODO: Is this necessary?
    # Guarded so a single-process run (no initialized process group) doesn't crash.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
@contextlib.contextmanager
def profiling_context(args, rank):
    """Context manager that yields a torch profiler, or an inert stand-in.

    When ``args["profile"]`` is truthy, configures ``torch.profiler.profile``
    with a wait/warmup/active schedule and ``trace_handler`` as the
    ``on_trace_ready`` callback. Otherwise yields a ``FakeProfiler`` exposing
    the same context-manager / ``step()`` surface so the training loop needs
    no branching.

    Args:
        args: Configuration mapping. Reads keys: ``profile``, ``model_name``,
            ``train_type``, ``profiling_output``, ``warmup_steps``,
            ``active_steps``, ``repeat``, ``profiling_frequency``,
            ``wait_steps``, ``export_memory_timeline``, ``with_stack``,
            ``with_shapes``.
        rank: Distributed rank, forwarded to the trace handler.

    Yields:
        A ``torch.profiler.profile`` instance, or a ``FakeProfiler`` when
        profiling is disabled.

    Raises:
        ValueError: If ``profiling_frequency`` is smaller than
            ``warmup_steps + active_steps`` (the derived wait would be negative).
    """
    if not args["profile"]:
        class FakeProfiler:
            """No-op stand-in used when profiling is not enabled."""
            def __enter__(self):
                return self
            def __exit__(self, *exc_info):
                pass
            def step(self):
                pass
        yield FakeProfiler()
        return

    model_name = args["model_name"].split("/")[-1]
    train_type = args["train_type"]
    output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}"
    # exist_ok=True already tolerates a pre-existing directory; no exists() check needed.
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Profiling enabled. Traces will be saved at {output_dir}")

    warmup = args["warmup_steps"]
    active = args["active_steps"]
    repeat = args["repeat"]
    if repeat == 0:
        # Fixed cycle length: whatever the warmup+active phases don't use is spent waiting.
        steps_per_cycle = args["profiling_frequency"]
        wait = steps_per_cycle - (active + warmup)
    else:
        wait = args["wait_steps"]
        steps_per_cycle = wait + warmup + active
    if wait < 0:
        # Raise instead of `assert`: asserts are stripped under `python -O`.
        raise ValueError(
            "profile_freq must be greater than or equal to warmup + active"
        )
    logger.info(f"Profiler schedule - steps per cycle: {steps_per_cycle} wait: {wait} warmup: {warmup} active: {active} repeat: {repeat if repeat !=0 else 'inf'}")

    export_memory_timeline = args["export_memory_timeline"]
    # Memory-timeline export requires both stacks and shapes to be recorded,
    # so it force-enables each of them. (Single read of the flag; the original
    # read it twice into two identical variables.)
    with_stack = args["with_stack"] or export_memory_timeline
    with_shapes = args["with_shapes"] or export_memory_timeline
    callback = partial(trace_handler, rank=rank,
                       export_memory_timeline=export_memory_timeline,
                       output_dir=output_dir,
                       with_stack=with_stack,
                       group_by_input_shape=with_shapes,
                       group_by_stack=5 if with_stack else 0)

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        with_stack=with_stack,
        profile_memory=export_memory_timeline,
        record_shapes=with_shapes,
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        on_trace_ready=callback,
        # verbose experimental config is only needed when stacks are collected.
        experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None,
    ) as torch_profiler:
        yield torch_profiler