Skip to content

Commit 35fc1fe

Browse files
authored
Fix accuracy of kernel summation (#77)
* Refactor and remove unused imports * Data collection script updates * Added more warmup plus mapping to full graph trace * More models to data collection script * add more models and gpus to sample data
1 parent 3a9dc08 commit 35fc1fe

File tree

5 files changed

+827
-267
lines changed

5 files changed

+827
-267
lines changed

centml/compiler/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Config(BaseSettings):
3838

3939
CENTML_MODE: OperationMode = OperationMode.REMOTE_COMPILATION
4040
CENTML_PREDICTION_DATA_FILE: str = 'tests/sample_data.csv'
41-
CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB"
41+
CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB,L4,H10080GBHBM3"
4242
CENTML_PROMETHEUS_PORT: int = 8000
4343

4444

centml/compiler/prediction/profiler.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch.fx
55
from torch.fx.node import Node
66

7+
from scripts.timer import timed
8+
79

810
class Profiler:
911
def __init__(self, mod, gpu, treeDB, data_collection_mode=False):
@@ -13,11 +15,30 @@ def __init__(self, mod, gpu, treeDB, data_collection_mode=False):
1315
self.tree_db = treeDB
1416
self.gpu = gpu
1517
self.data_collection_mode = data_collection_mode
18+
self.trace_event_idx = 0
1619

1720
def propagate(self, *args):
1821
args_iter = iter(args)
1922
env: Dict[str, Node] = {}
20-
total_time = 0
23+
total_gpu_time = 0
24+
actual_time = 0
25+
trace_events = []
26+
if self.data_collection_mode:
27+
# Warmup before profiling
28+
for _ in range(10):
29+
_, t = timed(lambda: self.mod(*args))
30+
31+
# actual_time is to compare prediction to execution time of GraphModule
32+
actual_time = t
33+
34+
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
35+
self.mod(*args)
36+
for event in prof.events():
37+
# Ignore CPU events for now
38+
if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU:
39+
continue
40+
# Create a mapping of kernel execution times to the corresponding trace events
41+
trace_events.append(event.time_range.elapsed_us())
2142

2243
def load_arg(a):
2344
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
@@ -81,14 +102,26 @@ def find_dtypes(results):
81102
def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs):
82103
t = self.tree_db.get(key, inp_shapes)
83104

84-
if self.data_collection_mode and t is None:
105+
if self.data_collection_mode:
85106
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
86107
operation(*args, **kwargs)
87-
event_time_total = 0
88-
for event in prof.key_averages():
89-
event_time_total += event.cuda_time_total
90-
t = event_time_total
91-
self.tree_db.add(key, inp_shapes, t)
108+
109+
if t is None:
110+
# New key
111+
event_time_total = 0
112+
for event in prof.events():
113+
if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU:
114+
continue
115+
event_time_total += trace_events[self.trace_event_idx]
116+
self.trace_event_idx += 1
117+
t = event_time_total
118+
self.tree_db.add(key, inp_shapes, t)
119+
else:
120+
# Existing key, increment trace_event_idx by # of events to maintain mapping to trace_events list
121+
for event in prof.events():
122+
if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU:
123+
continue
124+
self.trace_event_idx += 1
92125

93126
return t
94127

@@ -110,7 +143,7 @@ def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs):
110143

111144
t = get_time_or_profile(key, inp_shapes, node.target, *args, **kwargs)
112145

113-
total_time += t
146+
total_gpu_time += t
114147
elif node.op == 'call_method':
115148
self_obj, *args = load_arg(node.args)
116149
kwargs = load_arg(node.kwargs)
@@ -123,7 +156,7 @@ def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs):
123156

124157
t = get_time_or_profile(key, inp_shapes, getattr(self_obj, node.target), *args, **kwargs)
125158

126-
total_time += t
159+
total_gpu_time += t
127160
elif node.op == 'call_module':
128161
mod = self.modules[node.target]
129162
args = load_arg(node.args)
@@ -145,9 +178,12 @@ def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs):
145178

146179
t = get_time_or_profile(key, inp_shapes, mod, *args, **kwargs)
147180

148-
total_time += t
181+
total_gpu_time += t
149182
elif node.op == 'output':
150183
args = load_arg(node.args)
151-
return args[0], total_time
184+
if self.data_collection_mode:
185+
# Return full graph execution time as well for accuracy comparison
186+
return args[0], total_gpu_time, actual_time
187+
return args[0], total_gpu_time
152188

153189
env[node.name] = result

scripts/data_collection.py

Lines changed: 132 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
1-
import argparse
21
import csv
32
import gc
43
import json
5-
import os
6-
import random
7-
import statistics
8-
import time
94

10-
import numpy as np
115
import torch
12-
import torchvision.models as models
13-
from sklearn.neighbors import KDTree
14-
from torch.profiler import ProfilerActivity, profile, record_function
156
from transformers import (
16-
AutoConfig,
177
AutoModelForCausalLM,
188
AutoTokenizer,
19-
BertConfig,
20-
BertForMaskedLM,
21-
GPT2ForSequenceClassification,
22-
PegasusConfig,
23-
PegasusForCausalLM,
9+
AutoModelForImageClassification,
10+
AutoModelForObjectDetection
2411
)
2512

13+
14+
2615
from centml.compiler.prediction.kdtree import KDTreeWithValues
2716
from centml.compiler.prediction.profiler import Profiler
17+
from scripts.timer import timed
2818

2919
torch.set_float32_matmul_precision('high')
3020
torch.set_default_device('cuda')
@@ -34,32 +24,74 @@
3424
OUTPUT_FILE = 'data.csv'
3525

3626
# Different HuggingFace Models + Different Input Sizes
37-
hf_model_tests = [
38-
("EleutherAI/gpt-neo-2.7B", (1, 512)),
27+
llm_tests = [
28+
("google/gemma-7b", (1, 128)),
29+
("microsoft/phi-2", (1,512)),
30+
("microsoft/phi-2", (2,512)),
31+
("facebook/bart-large", (1, 1024)),
32+
("facebook/bart-large", (2, 512)),
3933
("gpt2-xl", (1, 1024)),
40-
("gpt2-large", (1, 1024)),
34+
("gpt2-xl", (1, 720)),
4135
("gpt2-xl", (1, 512)),
36+
("gpt2-xl", (2, 512)),
37+
("gpt2-xl", (4, 256)),
38+
("EleutherAI/gpt-neo-2.7B", (1, 512)),
39+
("EleutherAI/gpt-neo-2.7B", (1, 256)),
40+
("gpt2-large", (1, 1024)),
41+
("gpt2-large", (1, 720)),
42+
("gpt2-large", (1, 512)),
4243
("google-bert/bert-large-uncased", (8, 512)),
4344
("google-bert/bert-large-uncased", (16, 512)),
44-
("meta-llama/Meta-Llama-3.1-8B", (1, 512)),
4545
("meta-llama/Meta-Llama-3.1-8B", (1, 256)),
4646
("gpt2-medium", (1, 1024)),
47-
("facebook/bart-large", (1, 1024)),
47+
("gpt2-medium", (1, 512)),
48+
("gpt2-medium", (2, 512)),
4849
("google/pegasus-cnn_dailymail", (1, 1024)),
50+
("google/pegasus-cnn_dailymail", (1, 512)),
51+
("google/pegasus-cnn_dailymail", (2, 512)),
4952
]
5053

51-
# Different Batch Sizes for each ResNet Model (torchvision)
52-
resnet_tests = [1024, 720, 1440]
53-
54+
# Tests for larger GPUs (A100, H100, etc.)
55+
# large_llm_tests = [
56+
# ("google/gemma-7b", (1, 256)),
57+
# ("google/gemma-7b", (1, 512)),
58+
# ("google/gemma-7b", (1, 1024)),
59+
# ("microsoft/phi-2", (1,1024)),
60+
# ("microsoft/phi-2", (1,2048)),
61+
# ("microsoft/phi-2", (2,1024)),
62+
# ("EleutherAI/gpt-neo-2.7B", (1, 1024)),
63+
# ("gpt2-xl", (2, 1024)),
64+
# ("gpt2-xl", (4, 512)),
65+
# ("meta-llama/Meta-Llama-3.1-8B", (1, 1024)),
66+
# ("meta-llama/Meta-Llama-3.1-8B", (1, 512)),
67+
# ("google/pegasus-cnn_dailymail", (4, 1024)),
68+
# ("facebook/bart-large", (4, 1024)),
69+
# ("facebook/bart-large", (2, 1024)),
70+
# ("google-bert/bert-large-uncased", (16, 512)),
71+
# ("gpt2-medium", (2, 1024)),
72+
# ("gpt2-medium", (4, 512)),
73+
# ("gpt2-large", (2, 1024)),
74+
# ("gpt2-large", (4, 512)),
75+
# ]
76+
77+
# Different Batch Sizes for each image classification model
78+
image_classification_tests = [
79+
("google/efficientnet-b0", 512),
80+
("google/efficientnet-b0", 256),
81+
("google/efficientnet-b0", 128),
82+
("google/vit-base-patch16-224", 128),
83+
("microsoft/resnet-50", 256),
84+
("microsoft/resnet-50", 512),
85+
]
5486

55-
def timed(fn):
56-
start = torch.cuda.Event(enable_timing=True)
57-
end = torch.cuda.Event(enable_timing=True)
58-
start.record()
59-
result = fn()
60-
end.record()
61-
torch.cuda.synchronize()
62-
return result, start.elapsed_time(end) / 1000
87+
# Different Batch Sizes for each object detection model
88+
object_detection_tests = [
89+
("hustvl/yolos-tiny", 128),
90+
("hustvl/yolos-tiny", 256),
91+
("hustvl/yolos-tiny", 512),
92+
("facebook/detr-resnet-50", 128),
93+
("facebook/detr-resnet-50", 256),
94+
]
6395

6496

6597
def percent_error(observed, true):
@@ -90,24 +122,28 @@ def get(self, key, inp):
90122

91123

92124
db = DataCollectionTreeDB()
93-
added_time = 0
125+
cuda_kernel_time = 0
126+
actual_time = 0
94127

95128

96129
def custom_backend(gm: torch.fx.GraphModule, inps):
97130
print("Compiling")
98131
profiler = Profiler(mod=gm, gpu=CURR_GPU, treeDB=db, data_collection_mode=True)
99132

100133
def forward(*args):
101-
global added_time
102-
out, t = profiler.propagate(*args)
103-
added_time += t
134+
global cuda_kernel_time
135+
global actual_time
136+
out, t, actual_t = profiler.propagate(*args)
137+
cuda_kernel_time += t
138+
actual_time += actual_t
104139
return out
105140

106141
return forward
107142

108143

109-
def hf_model_test(model_name, input_size, custom_backend):
110-
global added_time
144+
def llm_test(model_name, input_size, custom_backend):
145+
global cuda_kernel_time
146+
global actual_time
111147
models_without_tokenizer = {"google/pegasus-cnn_dailymail"}
112148

113149
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
@@ -131,22 +167,55 @@ def hf_model_test(model_name, input_size, custom_backend):
131167
compiled_model = torch.compile(model, backend=custom_backend)
132168
compiled_model(inp)
133169

134-
added_time /= 1000000
170+
cuda_kernel_time /= 1000000
135171

136172
print(f"{model_name}, {input_size}")
137-
print("Real time: ", t)
138-
print("TOTAL TIME: ", added_time)
139-
print("Error: ", percent_error(added_time, t))
173+
print("Real time: ", actual_time)
174+
print("Kernel execution time: ", cuda_kernel_time)
175+
print("Error: ", percent_error(cuda_kernel_time, actual_time))
140176

141-
added_time = 0
177+
cuda_kernel_time = 0
178+
actual_time = 0
142179
del model, inp, compiled_model
143180
gc.collect()
144181
torch.cuda.empty_cache()
145182

146183

147-
def resnet_test(batch_size, custom_backend):
148-
global added_time
149-
model = models.resnet50(weights=True, num_classes=1000).cuda()
184+
def image_classification_test(model_name, batch_size, custom_backend):
185+
global cuda_kernel_time
186+
global actual_time
187+
model = AutoModelForImageClassification.from_pretrained(model_name).to("cuda:0")
188+
model.eval()
189+
if model_name == "google/vit-base-patch16-224":
190+
inp = torch.randn(batch_size, 3, 224, 224).cuda(0)
191+
else:
192+
inp = torch.randn(batch_size, 3, 128, 128).cuda(0)
193+
194+
with torch.inference_mode():
195+
for _ in range(10):
196+
_, t = timed(lambda: model(inp))
197+
print(t)
198+
199+
compiled_model = torch.compile(model, backend=custom_backend)
200+
compiled_model(inp)
201+
202+
cuda_kernel_time /= 1000000
203+
204+
print(f"{model_name}, {batch_size}")
205+
print("Real time: ", actual_time)
206+
print("TOTAL TIME: ", cuda_kernel_time)
207+
print("Error: ", percent_error(cuda_kernel_time, actual_time))
208+
209+
cuda_kernel_time = 0
210+
actual_time = 0
211+
del model, inp, compiled_model
212+
gc.collect()
213+
torch.cuda.empty_cache()
214+
215+
def object_detection_test(model_name, batch_size, custom_backend):
216+
global cuda_kernel_time
217+
global actual_time
218+
model = AutoModelForObjectDetection.from_pretrained(model_name).to("cuda:0")
150219
model.eval()
151220
inp = torch.randn(batch_size, 3, 128, 128).cuda(0)
152221

@@ -157,22 +226,31 @@ def resnet_test(batch_size, custom_backend):
157226

158227
compiled_model = torch.compile(model, backend=custom_backend)
159228
compiled_model(inp)
160-
print(f"resnet, ({batch_size}, 3, 128, 128)")
161-
print("Real time: ", t)
162-
print("TOTAL TIME: ", added_time)
163-
print("Error: ", percent_error(added_time, t))
164229

165-
added_time = 0
230+
cuda_kernel_time /= 1000000
231+
232+
print(f"{model_name}, {batch_size}")
233+
print("Real time: ", actual_time)
234+
print("TOTAL TIME: ", cuda_kernel_time)
235+
print("Error: ", percent_error(cuda_kernel_time, actual_time))
236+
237+
cuda_kernel_time = 0
238+
actual_time = 0
166239
del model, inp, compiled_model
167240
gc.collect()
168241
torch.cuda.empty_cache()
169242

243+
# for model_name, input_size in large_llm_tests:
244+
# llm_test(model_name, input_size, custom_backend)
245+
246+
for model_name, input_size in llm_tests:
247+
llm_test(model_name, input_size, custom_backend)
170248

171-
for model_name, input_size in hf_model_tests:
172-
hf_model_test(model_name, input_size, custom_backend)
249+
for model_name, batch_size in object_detection_tests:
250+
object_detection_test(model_name, batch_size, custom_backend)
173251

174-
for batch_size in resnet_tests:
175-
resnet_test(batch_size, custom_backend)
252+
for model_name, batch_size in image_classification_tests:
253+
image_classification_test(model_name, batch_size, custom_backend)
176254

177255
# Write to CSV
178256
with open(OUTPUT_FILE, 'w', newline='') as csvfile:

scripts/timer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
3+
def timed(fn):
4+
start = torch.cuda.Event(enable_timing=True)
5+
end = torch.cuda.Event(enable_timing=True)
6+
start.record()
7+
result = fn()
8+
end.record()
9+
torch.cuda.synchronize()
10+
return result, start.elapsed_time(end) / 1000

0 commit comments

Comments
 (0)