diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py index dbd4e9a70..cac1d99f9 100644 --- a/torchrec/distributed/benchmark/base.py +++ b/torchrec/distributed/benchmark/base.py @@ -571,6 +571,11 @@ def main(self) -> None: ) +def create_trace_file_name(profile_name: str, rank: int) -> str: + """Create a unique trace file name for the given rank and profile name.""" + return f"trace-{profile_name}-rank{rank}.json.gz" + + def init_argparse_and_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -727,9 +732,10 @@ def _trace_handler(prof: torch.profiler.profile) -> None: if not all_rank_traces and rank > 0: # only save trace for rank 0 when all_rank_traces is disabled return - trace_file = f"{output_dir}/trace-{name}-rank{rank}.json" + trace_file = f"{output_dir}/{create_trace_file_name(name, rank)}" logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") prof.export_chrome_trace(trace_file) + if export_stacks: prof.export_stacks( f"{output_dir}/stacks-cpu-{name}.stacks", "self_cpu_time_total"