Commit e34f121

Add profiler and Perfetto UI link with comprehensive tests (#1984, #1992)
ghstack-source-id: a5f8301acb77a180a395aa8dd4c1aa9c2ccd7522
ghstack-comment-id: 2770609971
Pull Request resolved: #1997
1 parent 70fc520 commit e34f121

5 files changed: +688 −158 lines

benchmarks/microbenchmarks/benchmark_inference.py

+89 −65
@@ -20,6 +20,8 @@
     BenchmarkResult,
     clean_caches,
     create_model_and_input,
+    generate_memory_profile,
+    generate_model_profile,
     model_inference_time_in_ms,
     string_to_config,
 )
@@ -29,70 +31,92 @@
 
 def run(config: BenchmarkConfig) -> BenchmarkResult:
     """Run inference benchmarks"""
-    clean_caches()  # Clean caches
-
-    # Create output directory if it doesn't exist
-    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
-
-    base_model, input_data = create_model_and_input(
-        config.model_type,
-        config.m,
-        config.k,
-        config.n,
-        high_precision_dtype=config.high_precision_dtype,
-        device=config.device,
-    )
-
-    # Use quantize_ to apply each quantization function to the model
-    m_copy = deepcopy(base_model).eval().to(config.device)
-    ao_base_config = string_to_config(
-        config.quantization,
-        config.sparsity,
-        high_precision_dtype=config.high_precision_dtype,
-    )
-
-    # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
-    is_cuda = config.device == "cuda" and torch.cuda.is_available()
-
-    if config.sparsity is not None and (
-        config.quantization is None or "baseline" in config.quantization
-    ):
-        if is_cuda:
-            print(f"Applying {config.sparsity} sparsity to model")
-            sparsify_(m_copy, ao_base_config)
+    try:
+        clean_caches()  # Clean caches
+
+        # Create output directory if it doesn't exist
+        Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+
+        base_model, input_data = create_model_and_input(
+            config.model_type,
+            config.m,
+            config.k,
+            config.n,
+            high_precision_dtype=config.high_precision_dtype,
+            device=config.device,
+        )
+
+        # Use quantize_ to apply each quantization function to the model
+        m_copy = deepcopy(base_model).eval().to(config.device)
+        ao_base_config = string_to_config(
+            config.quantization,
+            config.sparsity,
+            high_precision_dtype=config.high_precision_dtype,
+        )
+
+        # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
+        is_cuda = config.device == "cuda" and torch.cuda.is_available()
+
+        if config.sparsity is not None and (
+            config.quantization is None or "baseline" in config.quantization
+        ):
+            if is_cuda:
+                print(f"Applying {config.sparsity} sparsity to model")
+                sparsify_(m_copy, ao_base_config)
+            else:
+                print(
+                    f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+                )
+        elif config.sparsity is None and (
+            config.quantization is None or "baseline" in config.quantization
+        ):
+            pass  # No quantization or sparsity specified, do nothing
         else:
-            print(
-                f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+            print("Quantizing model....")
+            quantize_(m_copy, ao_base_config)
+
+        if config.use_torch_compile:
+            print("Compiling model....")
+            m_copy = torch.compile(
+                m_copy, mode=config.torch_compile_mode, fullgraph=True
             )
-    elif config.sparsity is None and (
-        config.quantization is None or "baseline" in config.quantization
-    ):
-        pass  # No quantization or sparsity specified, do nothing
-    else:
-        print("Quantizing model....")
-        quantize_(m_copy, ao_base_config)
-
-    if config.use_torch_compile:
-        print("Compiling model....")
-        m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
-
-    # Run benchmarks
-    result = BenchmarkResult(config=config)
-
-    # Benchmark time to run an inference call for quantized model
-    result.model_inference_time_in_ms = model_inference_time_in_ms(
-        model=m_copy, input_data=input_data
-    )
-
-    # TODO: Benchmark time using profiler
-    # Profile dtype model evaluation
-    # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
-    # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json")  # Save profiling details
-
-    # TODO: Benchmark gemm time using cuda graph
-    # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)
-
-    # TODO: Benchmark op with cuda graph
-    # time = benchmark_op_with_cuda_graph(op, args)
-
-    return result
+
+        # Run benchmarks
+        result = BenchmarkResult(config=config)
+        # Store result in model for memory profiling
+        m_copy._benchmark_result = result
+
+        # Benchmark time to run an inference call for quantized model
+        result.model_inference_time_in_ms = model_inference_time_in_ms(
+            model=m_copy, input_data=input_data
+        )
+
+        # Run profiler if enabled
+        if config.enable_profiler:
+            print("Running profiler...")
+            try:
+                result.profiler_json_path, result.perfetto_url = generate_model_profile(
+                    m_copy, input_data, config.profiler_file_name
+                )
+            except Exception as e:
+                print(f"Error running profiler: {e}")
+
+        # Run memory profiler if enabled
+        if config.enable_memory_profile:
+            print("Running memory profiler...")
+            try:
+                result.memory_profile_path, result.memory_stats = (
+                    generate_memory_profile(
+                        m_copy, input_data, config.memory_profile_file_name
+                    )
+                )
+            except Exception as e:
+                print(f"Error running memory profiler: {e}")
+
+        return result
+    except Exception as e:
+        print(f"Error in benchmark run: {e}")
+        import traceback

        print(traceback.format_exc())
        return None
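
The two helpers imported above, generate_model_profile and generate_memory_profile, are defined in the PR's profiler/utils changes, which are among the two changed files not shown in this excerpt. As orientation only, here is a minimal sketch of what such helpers could look like, built on the public torch.profiler and torch.cuda memory APIs; the names, signatures, return values (including how the Perfetto UI link is produced), and file layout are assumptions, not the PR's actual implementation.

# Hypothetical sketch only -- the real helpers added by this PR live in files not
# shown in this excerpt; names, signatures, and return values may differ.
import json
import os
from typing import Dict, Tuple

import torch
from torch.profiler import ProfilerActivity, profile


def generate_model_profile(
    model: torch.nn.Module, input_data: torch.Tensor, profile_file_path: str
) -> Tuple[str, str]:
    """Trace one inference pass and save a Chrome trace that Perfetto can open."""
    os.makedirs(os.path.dirname(profile_file_path) or ".", exist_ok=True)

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with torch.no_grad(), profile(activities=activities, record_shapes=True) as prof:
        model(input_data)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # The exported JSON can be opened manually in the Perfetto UI.
    prof.export_chrome_trace(profile_file_path)
    return profile_file_path, "https://ui.perfetto.dev/"


def generate_memory_profile(
    model: torch.nn.Module, input_data: torch.Tensor, memory_profile_file_path: str
) -> Tuple[str, Dict[str, int]]:
    """Record peak CUDA memory for one inference pass and dump the stats to JSON."""
    stats: Dict[str, int] = {}
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            model(input_data)
        torch.cuda.synchronize()
        stats = {
            "peak_memory_allocated_bytes": torch.cuda.max_memory_allocated(),
            "peak_memory_reserved_bytes": torch.cuda.max_memory_reserved(),
        }
    os.makedirs(os.path.dirname(memory_profile_file_path) or ".", exist_ok=True)
    with open(memory_profile_file_path, "w") as f:
        json.dump(stats, f, indent=2)
    return memory_profile_file_path, stats

Perfetto opens Chrome-format traces, which is presumably why the benchmark stores a perfetto_url alongside profiler_json_path in the result.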

benchmarks/microbenchmarks/benchmark_runner.py

+14 −8
@@ -164,16 +164,22 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
                 f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
             )
             result = run_inference(config)  # Pass the config object directly
-            results.append(result)
-        except Exception:
-            print(f"Error running benchmark {config.name}")
-            continue
+            if result is not None:  # Only add successful results
+                results.append(result)
+        except Exception as e:
+            import traceback
 
-    # Add results to csv
-    generate_results_csv(results, configs[0].output_dir)
+            print(f"Error running benchmark {config.name} with error: {e}")
+            print(traceback.format_exc())
+            continue
 
-    # Print results
-    print_results(results)
+    # Add results to csv if there are any
+    if results:
+        generate_results_csv(results, configs[0].output_dir)
+        # Print results
+        print_results(results)
+    else:
+        print("No benchmark results were collected. All benchmarks failed.")
 
     # TODO: Process results: Speedups:
     # 1. For different shapes for same model and quantization

benchmarks/microbenchmarks/test/benchmark_config.yml

+33 −28
@@ -2,46 +2,51 @@
 benchmark_mode: "inference"
 quantization_config_recipe_names:
   # Will run a baseline inference for model by default, without quantization for comparison
-  - "int4wo-32"
-  - "marlin"
-sparsity_config_recipe_names:
+  # - "int4wo-32"
+  # - "marlin"
+  - "int8wo"
+# sparsity_config_recipe_names:
   # Will run a baseline inference for model by default, without sparsity for comparison
-  - "semi-sparse"
-  - "block"
+  # - "semi-sparse"
+  # - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
-  - name: "small_bf16_linear"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [1024, 1024, 1024], # [m, k, n]
-        ]
-    high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
-    torch_compile_mode: "max-autotune"
-    device: "cuda"
-    model_type: "linear"
+  # - name: "small_bf16_linear"
+  #   matrix_shapes:
+  #     - name: "custom"
+  #       shapes: [
+  #         [1024, 1024, 1024], # [m, k, n]
+  #       ]
+  #   high_precision_dtype: "torch.bfloat16"
+  #   use_torch_compile: true
+  #   torch_compile_mode: "max-autotune"
+  #   device: "cuda"
+  #   model_type: "linear"
+  #   enable_profiler: true # Enable profiling for this model
 
   - name: "large_bf16_ln_linear"
     matrix_shapes:
       - name: "custom"
        shapes: [
           [2048, 4096, 1024],
-          [4096, 4096, 1024]
+          # [4096, 4096, 1024]
         ]
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
-    model_type: "ln_linear_sigmoid"
-
-  - name: "cpu_fp32_linear"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [4096, 4096, 1024]
-        ]
-    high_precision_dtype: "torch.float32"
-    use_torch_compile: false
-    device: "cpu"
     model_type: "linear"
+    enable_profiler: true # Enable profiling for this model
+    enable_memory_profile: true # Enable memory profiling for this model
+
+  # - name: "cpu_fp32_linear"
+  #   matrix_shapes:
+  #     - name: "custom"
+  #       shapes: [
+  #         [4096, 4096, 1024]
+  #       ]
+  #   high_precision_dtype: "torch.float32"
+  #   use_torch_compile: false
+  #   device: "cpu"
+  #   model_type: "linear"
+  #   enable_profiler: true # Enable profiling for this model
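
With only the large_bf16_ln_linear entry active and both enable_profiler and enable_memory_profile set, this config exercises the new profiling path end to end. It would typically be run through the microbenchmark runner, e.g. python -m benchmarks.microbenchmarks.benchmark_runner --config benchmarks/microbenchmarks/test/benchmark_config.yml from the repo root; the --config flag name is assumed from the benchmarks layout and may differ.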
