import os
from copy import deepcopy
from pathlib import Path
+ from typing import Dict, Tuple

import torch

    create_model_and_input_data,
)

+ # -----------------------------------------------------------------------------
+ # Baseline caching
+ #
+ # ``_BASELINE_CACHE`` maps a key built by ``_make_cache_key(config)``, i.e.
+ # (model_type, m, k, n, high_precision_dtype, device, torch_compile_mode),
+ # to a tuple ``(eager_baseline_time, compile_baseline_time)``.
+ # E.g.: (linear, 1024, 1024, 1024, torch.bfloat16, cuda, default) -> (95.00, 56.00)
+ # The cache is internal to this module and should not be accessed directly.
+ # Caching the baseline inference times avoids re-measuring the baseline for
+ # every quantization/sparsity variant of the same configuration, which keeps
+ # overall benchmarking time down when computing speedup metrics.
+ # -----------------------------------------------------------------------------
+
+ _BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}
+
+
+ def _make_cache_key(config: BenchmarkConfig) -> Tuple:
54+ """Create a key for caching based on benchmark configuration.
55+
56+ Parameters that affect baseline performance are included:
57+
58+ * model type (e.g. ``linear`` or ``transformer_block``)
59+ * shape dimensions (m, k, n)
60+ * high precision dtype (bf16, fp16, etc.)
61+ * device (cuda, cpu, mps)
62+ * compile settings (whether compile is enabled and compile mode)
+
+     Sparsity and quantization settings are deliberately excluded
+     because the baseline (non-quantized, non-sparse) performance is
+     independent of those attributes.
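+
+     For illustration, a key has the form (values here are just an example)::
+
+         ("linear", 1024, 1024, 1024, torch.bfloat16, "cuda", "default")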
67+ """
68+ return (
69+ config .model_type ,
70+ config .m ,
71+ config .k ,
72+ config .n ,
73+ config .high_precision_dtype ,
74+ config .device ,
75+ config .torch_compile_mode ,
76+ )
77+
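+ # A minimal sketch of the intended lookup pattern (``eager_time_ms`` and
+ # ``compile_time_ms`` stand in for measured timings; ``run`` below does the
+ # real measurement and caching):
+ #
+ #     key = _make_cache_key(config)
+ #     if key not in _BASELINE_CACHE:
+ #         _BASELINE_CACHE[key] = (eager_time_ms, compile_time_ms)
+ #     eager_baseline, compile_baseline = _BASELINE_CACHE[key]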

def run(config: BenchmarkConfig) -> BenchmarkResult:
-     """Run inference benchmarks"""
80+ """
81+ Run inference benchmarks.
82+
83+ The function first checks if a baseline for the given configuration
84+ already exists in the internal cache. If not, it measures the baseline
85+ inference time and stores the result. When the baseline is cached,
86+ the function reuses the cached baselines to calculate speedup metrics.
87+
88+ Args:
89+ config (BenchmarkConfig): Benchmark configuration.
90+
91+ Returns:
92+ BenchmarkResult: Result of the benchmark.
93+ """
    try:
        clean_caches()  # Clean caches

        # Create output directory if it doesn't exist
        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

+         # Prepare result container
+         result = BenchmarkResult(config=config)
+
+         # Create model and input data
        base_model, input_data = create_model_and_input_data(
            config.model_type,
            config.m,
@@ -51,28 +109,47 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
            high_precision_dtype=config.high_precision_dtype,
            device=config.device,
        )
-         # Copy base model for quantizing
-         m_copy = deepcopy(base_model)

-         # Run benchmarks
-         result = BenchmarkResult(config=config)
+         # Generate a cache key for the current configuration
+         cache_key = _make_cache_key(config)

-         # Store result in model for memory profiling
-         base_model._benchmark_result = result
-
-         # Run baseline benchmarking
-         base_model = base_model.eval().to(config.device)
-         if config.use_torch_compile:
-             print("Compiling baseline model....")
-             base_model = torch.compile(
-                 base_model, mode=config.torch_compile_mode, fullgraph=True
+         # Check if the baseline for this configuration has been computed
+         if cache_key not in _BASELINE_CACHE:
+             # Copy the model, switch it to eval mode and move it to the device
+             m_copy = deepcopy(base_model)
+             m_copy = m_copy.eval().to(config.device)
+             print("Benchmarking eager baseline inference.....")
+             eager_baseline_time = model_inference_time_in_ms(
+                 model=m_copy, input_data=input_data
            )
-         # Benchmark time to run an inference call for baseline model
-         print("Benchmarking baseline inference.....")
-         result.baseline_inference_time_in_ms = model_inference_time_in_ms(
-             model=base_model, input_data=input_data
-         )

+             print("Benchmarking compiled baseline inference.....")
+             m_copy = torch.compile(
+                 m_copy, mode=config.torch_compile_mode, fullgraph=True
+             )
+             compile_baseline_time = model_inference_time_in_ms(
+                 model=m_copy, input_data=input_data
+             )
+
+             # Cache the eager and compiled baseline times for this configuration
+             _BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)
+
+             result.baseline_model_eager_inference_time_in_ms = eager_baseline_time
+             result.baseline_model_compiled_inference_time_in_ms = compile_baseline_time
+         else:
+             # Retrieve cached values
+             cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
+             result.baseline_model_eager_inference_time_in_ms = cached_eager_time
+             result.baseline_model_compiled_inference_time_in_ms = cached_compile_time
+
+         # At this point, ``base_model`` is an uncompiled model ready for quantization,
+         # and ``input_data`` is the corresponding input tensor. The eager and compiled
+         # baseline times have been stored on ``result``.
+
+         # Copy base model for quantizing/sparsifying
+         m_copy = deepcopy(base_model)
+
+         # Determine quantization/sparsity configuration
        ao_base_config = string_to_config(
            config.quantization,
            config.sparsity,
@@ -101,24 +178,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
        m_copy = m_copy.eval().to(config.device)
        quantize_(m_copy, ao_base_config)

-         if config.use_torch_compile:
-             print("Compiling quantized model....")
-             m_copy = torch.compile(
-                 m_copy, mode=config.torch_compile_mode, fullgraph=True
-             )
-
        # Store result in model for memory profiling
        m_copy._benchmark_result = result

-         # Benchmark time to run an inference call for quantized model
-         print("Benchmarking quantized model.....")
-         result.model_inference_time_in_ms = model_inference_time_in_ms(
+         # Measure inference time for the eager quantized model
+         print("Benchmarking eager quantized model.....")
+         result.quantized_model_eager_inference_time_in_ms = model_inference_time_in_ms(
            model=m_copy, input_data=input_data
        )

-         # Calculate speedup w.r.t. baseline
-         result.speedup = round(
-             result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
+         # Measure inference time for compiled quantized model
+         print("Benchmarking compiled quantized model.....")
+         m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
+         result.quantized_model_compiled_inference_time_in_ms = (
+             model_inference_time_in_ms(model=m_copy, input_data=input_data)
+         )
+
+         # Compute eager speedup relative to baseline
+         result.eager_speedup_on_baseline = round(
+             result.baseline_model_eager_inference_time_in_ms
+             / result.quantized_model_eager_inference_time_in_ms,
+             ndigits=2,
+         )
+         # Compute compile speedup relative to baseline
+         result.compile_speedup_on_baseline = round(
+             result.baseline_model_compiled_inference_time_in_ms
+             / result.quantized_model_compiled_inference_time_in_ms,
+             ndigits=2,
+         )
+         # Compute compile speedup for quantized model relative to eager quantized model
+         result.compile_speedup_on_eager = round(
+             result.quantized_model_eager_inference_time_in_ms
+             / result.quantized_model_compiled_inference_time_in_ms,
+             ndigits=2,
        )
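+         # Worked example with hypothetical timings: an eager baseline of 95.0 ms,
+         # an eager quantized time of 60.0 ms, a compiled baseline of 56.0 ms and a
+         # compiled quantized time of 40.0 ms give
+         # eager_speedup_on_baseline == round(95.0 / 60.0, 2) == 1.58,
+         # compile_speedup_on_baseline == round(56.0 / 40.0, 2) == 1.4 and
+         # compile_speedup_on_eager == round(60.0 / 40.0, 2) == 1.5.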

        # Run profiler if enabled
@@ -165,9 +257,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
                        result.memory_profile_path
                    )
            except ValueError as e:
-                 if "not enough values to unpack" in e:
+                 if "not enough values to unpack" in str(e):
                    print(
170- "Failed due to existing bugs, re- run the code to generate memory profile. Please raise an issue if it persists."
262+ "Failed due to existing bugs, re‑ run the code to generate memory profile. Please raise an issue if it persists."
                    )
            except Exception as e:
                print(f"Error running memory profiler: {e}")