Skip to content

Commit

Permalink
save pytorch profile data to csv
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Apr 4, 2023
1 parent 82317c2 commit 0b52f29
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 35 deletions.
34 changes: 25 additions & 9 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import warnings
from collections import defaultdict
from contextlib import nullcontext

Expand Down Expand Up @@ -37,6 +38,10 @@ def full_batch_inference(model, data):
def run(args: argparse.ArgumentParser):
csv_data = defaultdict(list)

if args.write_csv == 'prof' and not args.profile:
warnings.warn(
"Cannot write profile data to csv because profiling is disabled.")

# cuda device is not suitable for full batch mode
device = torch.device(
'cuda' if not args.full_batch and torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -170,7 +175,8 @@ def run(args: argparse.ArgumentParser):
else:
cpu_affinity = nullcontext()
profile = torch_profile(
) if args.profile else nullcontext()
args.export_chrome_trace, csv_data,
args.write_csv) if args.profile else nullcontext()
itt = emit_itt(
) if args.vtune_profile else nullcontext()

Expand Down Expand Up @@ -213,7 +219,7 @@ def run(args: argparse.ArgumentParser):
print(f'Mini Batch Test Accuracy: \
{test_acc:.4f}')

if args.profile:
if args.profile and args.export_chrome_trace:
rename_profile_file(model_name, dataset_name,
str(batch_size), str(layers),
str(hidden_channels),
Expand All @@ -228,13 +234,20 @@ def run(args: argparse.ArgumentParser):
print(f'Throughput: {throughput:.3f} samples/s')
print(f'Latency: {latency:.3f} ms')

save_benchmark_data(csv_data, batch_size, layers,
num_neighbors, hidden_channels,
total_time, model_name,
dataset_name,
args.use_sparse_tensor)
num_records = 1
if args.write_csv == 'prof':
# For profiling with pytorch, we save top 5 most
# time consuming ops. Therefore, the same data
# should be entered for each of them.
num_records = 5
for _ in range(num_records):
save_benchmark_data(csv_data, batch_size, layers,
num_neighbors, hidden_channels,
total_time, model_name,
dataset_name,
args.use_sparse_tensor)
if args.write_csv:
write_to_csv(csv_data)
write_to_csv(csv_data, args.write_csv)


if __name__ == '__main__':
Expand Down Expand Up @@ -274,5 +287,8 @@ def run(args: argparse.ArgumentParser):
add('--full-batch', action='store_true', help='Use full batch mode')
add('--evaluate', action='store_true')
add('--ckpt_path', type=str, help='Checkpoint path for loading a model')
add('--write-csv', action='store_true', help='Write benchmark data to csv')
add('--write-csv', choices=[None, 'bench', 'prof'], default=None,
help='Write benchmark or pytorch profile data to csv.')
add('--export-chrome-trace', default=True, type=bool,
help='Export chrome trace file. Works only with pytorch profile')
run(argparser.parse_args())
44 changes: 31 additions & 13 deletions benchmark/training/training_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def train_hetero(model, loader, optimizer, device, progress_bar=True, desc="",
def run(args: argparse.ArgumentParser):
csv_data = defaultdict(list)

if args.write_csv == 'prof' and not args.profile:
warnings.warn(
"Cannot write profile data to csv because profiling is disabled.")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# If we use a custom number of steps, then we need to use RandomSampler,
# which already does shuffle.
Expand Down Expand Up @@ -242,15 +246,19 @@ def run(args: argparse.ArgumentParser):
print(f'Test Accuracy: {test_acc:.4f}')

if args.profile:
with torch_profile():
profile = torch_profile(
args.export_chrome_trace, csv_data,
args.write_csv)
with profile:
train(model, subgraph_loader, optimizer,
device, progress_bar=progress_bar,
desc="Profile training")
rename_profile_file(model_name, dataset_name,
str(batch_size),
str(layers),
str(hidden_channels),
str(num_neighbors))
if args.export_chrome_trace:
rename_profile_file(
model_name, dataset_name,
str(batch_size), str(layers),
str(hidden_channels),
str(num_neighbors))

total_time = t.duration
if args.num_steps != -1:
Expand All @@ -262,13 +270,20 @@ def run(args: argparse.ArgumentParser):
print(f'Throughput: {throughput:.3f} samples/s')
print(f'Latency: {latency:.3f} ms')

save_benchmark_data(csv_data, batch_size, layers,
num_neighbors, hidden_channels,
total_time, model_name,
dataset_name,
args.use_sparse_tensor)
num_records = 1
if args.write_csv == 'prof':
# For profiling with pytorch, we save top 5 most
# time consuming ops. Therefore, the same data
# should be entered for each of them.
num_records = 5
for _ in range(num_records):
save_benchmark_data(csv_data, batch_size, layers,
num_neighbors, hidden_channels,
total_time, model_name,
dataset_name,
args.use_sparse_tensor)
if args.write_csv:
write_to_csv(csv_data, training=True)
write_to_csv(csv_data, args.write_csv, training=True)


if __name__ == '__main__':
Expand Down Expand Up @@ -309,7 +324,10 @@ def run(args: argparse.ArgumentParser):
help='Enable filter-per-worker feature of the dataloader.')
add('--measure-load-time', action='store_true')
add('--evaluate', action='store_true')
add('--write-csv', action='store_true', help='Write benchmark data to csv')
add('--write-csv', choices=[None, 'bench', 'prof'], default=None,
help='Write benchmark or pytorch profile data to csv.')
add('--export-chrome-trace', default=True, type=bool,
help='Export chrome trace file. Works only with pytorch profile')
add('--trim', action='store_true', help="Use `trim_to_layer` optimization")
args = argparser.parse_args()

Expand Down
10 changes: 7 additions & 3 deletions benchmark/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,20 @@ def save_benchmark_data(csv_data, batch_size, layers, num_neighbors,
csv_data['SPARSE'].append(use_sparse_tensor)


def write_to_csv(csv_data, training=False):
def write_to_csv(csv_data, write_csv='bench', training=False):
import pandas as pd
results_path = osp.join(osp.dirname(osp.realpath(__file__)), '../results/')
os.makedirs(results_path, exist_ok=True)

name = 'training' if training else 'inference'
csv_path = osp.join(results_path, f'TOTAL_{name}_benchmark.csv')
csv_file_name = f'TOTAL_{name}_benchmark.csv' if write_csv == 'bench' \
else f'TOTAL_prof_{name}_benchmark.csv'
csv_path = osp.join(results_path, csv_file_name)
index_label = 'TEST_ID' if write_csv == 'bench' else 'ID'

with_header = not osp.exists(csv_path)
df = pd.DataFrame(csv_data)
df.to_csv(csv_path, mode='a', index_label='TEST_ID', header=with_header)
df.to_csv(csv_path, mode='a', index_label=index_label, header=with_header)


@torch.no_grad()
Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/profile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from .profile import profileit, timeit, get_stats_summary
from .profile import trace_handler, rename_profile_file, torch_profile
from .profile import (
trace_handler,
print_time_total,
rename_profile_file,
torch_profile,
)
from .utils import count_parameters
from .utils import get_model_size
from .utils import get_data_size
Expand All @@ -13,6 +18,7 @@
'timeit',
'get_stats_summary',
'trace_handler',
'print_time_total',
'rename_profile_file',
'torch_profile',
'count_parameters',
Expand Down
70 changes: 61 additions & 9 deletions torch_geometric/profile/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, List, NamedTuple, Tuple

import torch
from torch.autograd.profiler_util import EventList
from torch.profiler import ProfilerActivity, profile

from torch_geometric.profile.utils import (
Expand Down Expand Up @@ -206,17 +207,19 @@ def read_from_memlab(line_profiler: Any) -> List[float]: # pragma: no cover


def trace_handler(p):
if torch.cuda.is_available():
profile_sort = 'self_cuda_time_total'
else:
profile_sort = 'self_cpu_time_total'
output = p.key_averages().table(sort_by=profile_sort)
print(output)
print_time_total(p)
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'timeline' + '.json'
p.export_chrome_trace(timeline_file)


def print_time_total(p):
profile_sort = 'self_cuda_time_total' if torch.cuda.is_available(
) else 'self_cpu_time_total'
output = p.key_averages().table(sort_by=profile_sort)
print(output)


def rename_profile_file(*args):
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'profile'
Expand All @@ -227,11 +230,60 @@ def rename_profile_file(*args):


@contextmanager
def torch_profile():
def torch_profile(export_chrome_trace=True, csv_data=None, write_csv=None):
use_cuda = torch.cuda.is_available()
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
if use_cuda:
activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, on_trace_ready=trace_handler) as p:
p_trace_handler = trace_handler if export_chrome_trace \
else print_time_total
p = profile(activities=activities, on_trace_ready=p_trace_handler)

with p:
yield
p.step()

if csv_data is not None and write_csv == 'prof':
profile_sort = 'self_cuda_time_total' if use_cuda \
else 'self_cpu_time_total'
events = EventList(
sorted(p.key_averages(),
key=lambda evt: getattr(evt, profile_sort), reverse=True),
use_cuda=use_cuda)

save_profile_data(csv_data, events, use_cuda)


def format_prof_time(time):
# The profile time is in micro seconds, so it needs to be formatted
# appropriately
return round(time / 1e6, 3)


def save_profile_data(csv_data, events, use_cuda):
sum_self_cpu_time_total = sum(
[event.self_cpu_time_total for event in events])
sum_cpu_time_total = sum([event.self_cpu_time_total for event in events])
sum_self_cuda_time_total = sum(
[event.self_cuda_time_total for event in events]) if use_cuda else 0

# Save top 5 most time consuming ops
for e in events[:5]:
csv_data['NAME'].append(e.key)
csv_data['SELF CPU %'].append(
round(e.self_cpu_time_total * 100.0 / sum_self_cpu_time_total, 3))
csv_data['SELF CPU'].append(format_prof_time(e.self_cpu_time_total))
csv_data['CPU TOTAL %'].append(
round(e.cpu_time_total * 100.0 / sum_cpu_time_total, 3))
csv_data['CPU TOTAL'].append(format_prof_time(e.cpu_time_total))
csv_data['CPU TIME AVG'].append(format_prof_time(e.cpu_time_total))
if use_cuda:
csv_data['SELF CUDA %'].append(e.self_cuda_time_total * 100.0 /
sum_self_cuda_time_total)
csv_data['SELF CUDA'].append(
format_prof_time(e.self_cuda_time_total))
csv_data['CUDA TOTAL'].append(format_prof_time(e.cpu_time_total))
csv_data['CUDA TIME AVG'].append(format_prof_time(
e.cpu_time_total))
csv_data['# OF CALLS'].append(e.count)

0 comments on commit 0b52f29

Please sign in to comment.