Skip to content

Commit f22ec7f

Browse files
authored
Benchmarking V2: framework impl (#40486)
* Start revamping benchmarking * Start refactoring benchmarking * Use Pandas for CSV * import fix * Remove benchmark files * Remove sample data * Address review comments * Benchmarking v2 * Fix llama bench parameters * Working checkpoint * Readme touchups * Remove unnecessary test * Massage the framework a bit * Small cleanup * Remove unnecessary flushes * Remove references to mock benchmark * Take commit ID from CLI * Address review comments * Use Events for thread comms * Tiny renaming
1 parent 459c1fa commit f22ec7f

File tree

7 files changed

+1851
-0
lines changed

7 files changed

+1851
-0
lines changed

benchmark_v2/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
benchmark_results/

benchmark_v2/README.md

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Benchmarking v2
2+
3+
A comprehensive benchmarking framework for transformer models that supports multiple execution modes (eager, compiled, kernelized), detailed performance metrics collection, and structured output format.
4+
5+
6+
## Quick Start
7+
8+
### Running All Benchmarks
9+
10+
```bash
11+
# Run all benchmarks with default settings
12+
python run_benchmarks.py
13+
14+
# Specify output directory
15+
python run_benchmarks.py --output-dir my_results
16+
17+
# Run with custom parameters
18+
python run_benchmarks.py \
19+
--warmup-iterations 5 \
20+
--measurement-iterations 10 \
21+
--num-tokens-to-generate 200
22+
```
23+
24+
### Running Specific Benchmarks
25+
26+
```bash
27+
# Include only specific benchmarks
28+
python run_benchmarks.py --include llama
29+
30+
# Exclude specific benchmarks
31+
python run_benchmarks.py --exclude old_benchmark
32+
33+
## Output Format
34+
35+
Results are saved as JSON files with the following structure:
36+
37+
```json
38+
{
39+
"model_name": "llama_2_7b",
40+
"benchmark_scenarios": [
41+
{
42+
"scenario_name": "eager_variant",
43+
"metadata": {
44+
"timestamp": "2025-01-XX...",
45+
"commit_id": "abc123...",
46+
"hardware_info": {
47+
"gpu_name": "NVIDIA A100",
48+
"gpu_memory_total": 40960,
49+
"cpu_count": 64
50+
},
51+
"config": {
52+
"variant": "eager",
53+
"warmup_iterations": 3,
54+
"measurement_iterations": 5
55+
}
56+
},
57+
"measurements": {
58+
"latency": {
59+
"mean": 2.45,
60+
"median": 2.43,
61+
"std": 0.12,
62+
"min": 2.31,
63+
"max": 2.67,
64+
"p95": 2.61,
65+
"p99": 2.65
66+
},
67+
"time_to_first_token": {
68+
"mean": 0.15,
69+
"std": 0.02
70+
},
71+
"tokens_per_second": {
72+
"mean": 87.3,
73+
"unit": "tokens/sec"
74+
}
75+
},
76+
"gpu_metrics": {
77+
"gpu_utilization_mean": 85.2,
78+
"gpu_memory_used_mean": 12450
79+
}
80+
}
81+
]
82+
}
83+
```
84+
85+
### Debug Mode
86+
87+
```bash
88+
python run_benchmarks.py --log-level DEBUG
89+
```
90+
91+
## Contributing
92+
93+
To add new benchmarks:
94+
95+
1. Create a new file in `benches/`
96+
2. Implement the `ModelBenchmark` interface
97+
3. Add a runner function (`run_<benchmark_name>` or `run_benchmark`)
98+
4. run_benchmarks.py

benchmark_v2/benches/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Benchmark implementations directory

benchmark_v2/benches/llama.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import logging
17+
from typing import Dict, Any, List
18+
19+
from benchmark_framework import ModelBenchmark
20+
21+
import torch
22+
23+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
24+
os.environ["TOKENIZERS_PARALLELISM"] = "1"
25+
torch.set_float32_matmul_precision("high")
26+
27+
class LLaMABenchmark(ModelBenchmark):
28+
"""Simplified LLaMA model benchmark implementation using the ModelBenchmark base class."""
29+
30+
def __init__(self, logger: logging.Logger):
31+
super().__init__(logger)
32+
self._default_prompt = "Why dogs are so cute?" # Custom prompt for LLaMA
33+
34+
35+
36+
def get_scenario_configs(self) -> List[Dict[str, Any]]:
37+
"""
38+
Get LLaMA-specific scenario configurations.
39+
40+
Returns:
41+
List of scenario configuration dictionaries
42+
"""
43+
return [
44+
# Eager variants
45+
{"variant": "eager", "compile_mode": None, "use_cache": True, "description": "Eager execution with cache"},
46+
47+
# Compiled variants
48+
{"variant": "compiled", "compile_mode": "max-autotune", "use_cache": True, "description": "Compiled with max autotune"},
49+
50+
# Kernelized variant (if available)
51+
{"variant": "kernelized", "compile_mode": "max-autotune", "use_cache": True, "description": "Kernelized execution"},
52+
]
53+
54+
def _is_kernelization_available(self) -> bool:
55+
"""Check if kernelization is available for LLaMA."""
56+
try:
57+
from kernels import Mode, kernelize
58+
return True
59+
except ImportError:
60+
self.logger.debug("Kernelization not available: kernels module not found")
61+
return False
62+
63+
def get_default_generation_config(self) -> Dict[str, Any]:
64+
"""Get LLaMA-specific generation configuration."""
65+
return {
66+
"do_sample": False,
67+
"top_p": 1.0,
68+
"temperature": 1.0,
69+
"repetition_penalty": 1.0,
70+
"max_new_tokens": None, # Will be set per scenario
71+
}
72+
73+
def get_model_init_kwargs(self, config) -> Dict[str, Any]:
74+
"""Get LLaMA-specific model initialization kwargs."""
75+
from benchmark_framework import BenchmarkConfig
76+
return {
77+
"torch_dtype": getattr(torch, config.torch_dtype),
78+
"attn_implementation": config.attn_implementation,
79+
"use_cache": True,
80+
}
81+
82+
def get_default_torch_dtype(self) -> str:
83+
"""Get default torch dtype for LLaMA."""
84+
return "float16" # LLaMA works well with float16
85+
86+
def get_default_device(self) -> str:
87+
"""Get default device for LLaMA."""
88+
return "cuda" # LLaMA prefers CUDA
89+
90+
91+
def run_llama(logger, output_dir, **kwargs):
92+
"""
93+
Run LLaMA benchmark with the given configuration.
94+
95+
Args:
96+
logger: Logger instance
97+
output_dir: Output directory for results
98+
**kwargs: Additional configuration options
99+
100+
Returns:
101+
Path to output file if successful
102+
"""
103+
from benchmark_framework import BenchmarkRunner
104+
105+
# Extract parameters with defaults
106+
model_id = kwargs.get('model_id', 'meta-llama/Llama-2-7b-hf')
107+
warmup_iterations = kwargs.get('warmup_iterations', 3)
108+
measurement_iterations = kwargs.get('measurement_iterations', 5)
109+
num_tokens_to_generate = kwargs.get('num_tokens_to_generate', 100)
110+
include_sdpa_variants = kwargs.get('include_sdpa_variants', True)
111+
device = kwargs.get('device', 'cuda')
112+
torch_dtype = kwargs.get('torch_dtype', 'float16')
113+
batch_size = kwargs.get('batch_size', 1)
114+
commit_id = kwargs.get('commit_id', None)
115+
116+
logger.info(f"Starting LLaMA benchmark for model: {model_id}")
117+
logger.info(f"Configuration: warmup={warmup_iterations}, measurement={measurement_iterations}, tokens={num_tokens_to_generate}")
118+
119+
try:
120+
# Create benchmark instance
121+
benchmark = LLaMABenchmark(logger)
122+
123+
# Create scenarios
124+
scenarios = benchmark.create_scenarios(
125+
model_id=model_id,
126+
warmup_iterations=warmup_iterations,
127+
measurement_iterations=measurement_iterations,
128+
num_tokens_to_generate=num_tokens_to_generate,
129+
include_sdpa_variants=include_sdpa_variants,
130+
device=device,
131+
torch_dtype=torch_dtype,
132+
batch_size=batch_size
133+
)
134+
135+
logger.info(f"Created {len(scenarios)} benchmark scenarios")
136+
137+
# Create runner and execute benchmarks
138+
runner = BenchmarkRunner(logger, output_dir)
139+
results = runner.run_benchmark(benchmark, scenarios, commit_id=commit_id)
140+
141+
if not results:
142+
logger.warning("No successful benchmark results")
143+
return None
144+
145+
# Save results
146+
model_name = model_id.split('/')[-1] # Extract model name from ID
147+
output_file = runner.save_results(model_name, results)
148+
149+
logger.info(f"LLaMA benchmark completed successfully. Results saved to: {output_file}")
150+
return output_file
151+
152+
except Exception as e:
153+
logger.error(f"LLaMA benchmark failed: {e}")
154+
import traceback
155+
logger.debug(traceback.format_exc())
156+
raise

0 commit comments

Comments
 (0)