Skip to content

Commit f423b62

Browse files
committed
fix creating leaf folder
Summary: the leaf folder wasn't being created, so no profiles were being written; create it if it doesn't exist.
1 parent 3fe45dc commit f423b62

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

torchtitan/tools/profiling.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -40,16 +40,14 @@ def maybe_enable_profiling(
4040

4141
def trace_handler(prof):
4242
curr_trace_dir_name = "iteration_" + str(prof.step_num)
43-
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
43+
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name, leaf_folder)
4444
if not os.path.exists(curr_trace_dir):
4545
os.makedirs(curr_trace_dir, exist_ok=True)
4646

4747
logger.info(f"Dumping profiler traces at step {prof.step_num}")
4848
begin = time.monotonic()
4949

50-
output_file = os.path.join(
51-
curr_trace_dir, leaf_folder, f"rank{rank}_trace.json"
52-
)
50+
output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json")
5351
prof.export_chrome_trace(output_file)
5452
logger.info(
5553
f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
@@ -123,13 +121,13 @@ def step(self, exit_ctx: bool = False):
123121
# dump as iteration_0_exit if OOM at iter 1
124122
curr_step = self.step_num - 1
125123
dir_name = f"iteration_{curr_step}_exit"
126-
curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
124+
curr_snapshot_dir = os.path.join(snapshot_dir, dir_name, leaf_folder)
127125
if not os.path.exists(curr_snapshot_dir):
128126
os.makedirs(curr_snapshot_dir, exist_ok=True)
129127
logger.info(f"Dumping memory snapshot at step {curr_step}")
130128
begin = time.monotonic()
131129
output_file = os.path.join(
132-
curr_snapshot_dir, leaf_folder, f"rank{rank}_memory_snapshot.pickle"
130+
curr_snapshot_dir, f"rank{rank}_memory_snapshot.pickle"
133131
)
134132
with open(output_file, "wb") as output:
135133
pickle.dump(torch.cuda.memory._snapshot(), output)

0 commit comments

Comments
 (0)