
Commit 9edbd0a

extra: Add benchmark script implemented in Python (#1298)
* Create bench.py
* Various benchmark results
* Update benchmark script with hardware name, and file checks
* Remove old benchmark results
* Add git shorthash
* Round to 2 digits on calculated floats
* Fix the header reference when sorting results
* Fix order of models
* Parse file name
* Simplify filecheck
* Improve print run print statement
* Use simplified model name
* Update benchmark_results.csv
* Process single or lists of processors and threads
* Ignore benchmark results, don't check in
* Move bench.py to extra folder
* Readme section on how to use
* Move command to correct location
* Use separate list for models that exist
* Handle subprocess error in git short hash check
* Fix filtered models list initialization
1 parent 707507f commit 9edbd0a

File tree

3 files changed: +237 -0 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -46,3 +46,5 @@ models/*.mlpackage
bindings/java/.gradle/
bindings/java/.idea/
.idea/

benchmark_results.csv

README.md

Lines changed: 13 additions & 0 deletions
@@ -709,6 +709,19 @@ took to execute it. The results are summarized in the following Github issue:

[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)

Additionally, a script to run whisper.cpp with different models and audio files is provided: [bench.py](extra/bench.py).

You can run it with the following command; by default it will run against any standard model in the models folder.

```bash
python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
```

It is written in Python with the intention of being easy to modify and extend for your benchmarking use case.

It outputs a CSV file with the results of the benchmarking.

## ggml format

The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
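To make the run matrix concrete: the `-t` and `-p` options added to the README above take comma-separated lists, and the script benchmarks every thread/processor combination for each model it finds. Below is a minimal sketch of that expansion, using the example command's values and a placeholder model list; it is illustrative only, not part of the commit.

```python
# Illustrative only: expand "-t 2,4,8 -p 1,2" into the per-model run matrix
# that bench.py iterates over. The model names here are placeholders; the real
# script scans models/ for the standard ggml-*.bin files.
from itertools import product

threads = [2, 4, 8]      # parsed from "-t 2,4,8"
processors = [1, 2]      # parsed from "-p 1,2"
models = ["ggml-tiny.en.bin", "ggml-base.en.bin"]  # placeholder subset

for model, thread, processor_count in product(models, threads, processors):
    print(f"./main -m models/{model} -t {thread} -p {processor_count} -f samples/jfk.wav")
```

Each such combination produces one timing row in the CSV the script writes.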

extra/bench.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
import os
import subprocess
import re
import csv
import wave
import contextlib
import argparse


# Custom action to handle comma-separated list
class ListAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, [int(val) for val in values.split(",")])


parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")

# Define the argument to accept a list
parser.add_argument(
    "-t",
    "--threads",
    dest="threads",
    action=ListAction,
    default=[4],
    help="List of thread counts to benchmark (comma-separated, default: 4)",
)

parser.add_argument(
    "-p",
    "--processors",
    dest="processors",
    action=ListAction,
    default=[1],
    help="List of processor counts to benchmark (comma-separated, default: 1)",
)

parser.add_argument(
    "-f",
    "--filename",
    type=str,
    default="./samples/jfk.wav",
    help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
)

# Parse the command line arguments
args = parser.parse_args()

sample_file = args.filename

threads = args.threads
processors = args.processors

# Define the models, threads, and processor counts to benchmark
models = [
    "ggml-tiny.en.bin",
    "ggml-tiny.bin",
    "ggml-base.en.bin",
    "ggml-base.bin",
    "ggml-small.en.bin",
    "ggml-small.bin",
    "ggml-medium.en.bin",
    "ggml-medium.bin",
    "ggml-large.bin",
]

metal_device = ""

# Initialize a dictionary to hold the results
results = {}

gitHashHeader = "Commit"
modelHeader = "Model"
hardwareHeader = "Hardware"
recordingLengthHeader = "Recording Length (seconds)"
threadHeader = "Thread"
processorCountHeader = "Processor Count"
loadTimeHeader = "Load Time (ms)"
sampleTimeHeader = "Sample Time (ms)"
encodeTimeHeader = "Encode Time (ms)"
decodeTimeHeader = "Decode Time (ms)"
sampleTimePerRunHeader = "Sample Time per Run (ms)"
encodeTimePerRunHeader = "Encode Time per Run (ms)"
decodeTimePerRunHeader = "Decode Time per Run (ms)"
totalTimeHeader = "Total Time (ms)"


def check_file_exists(file: str) -> bool:
    return os.path.isfile(file)


def get_git_short_hash() -> str:
    try:
        return (
            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
            .decode()
            .strip()
        )
    except subprocess.CalledProcessError as e:
        return ""


def wav_file_length(file: str = sample_file) -> float:
    with contextlib.closing(wave.open(file, "r")) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)
        return duration


def extract_metrics(output: str, label: str) -> tuple[float, float]:
    match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
    time = float(match.group(1)) if match else None
    runs = float(match.group(2)) if match else None
    return time, runs


def extract_device(output: str) -> str:
    match = re.search(r"picking default device: (.*)", output)
    device = match.group(1) if match else "Not found"
    return device


# Check if the sample file exists
if not check_file_exists(sample_file):
    raise FileNotFoundError(f"Sample file {sample_file} not found")

recording_length = wav_file_length()


# Check that all models exist
# Filter out models from list that are not downloaded
filtered_models = []
for model in models:
    if check_file_exists(f"models/{model}"):
        filtered_models.append(model)
    else:
        print(f"Model {model} not found, removing from list")

models = filtered_models

# Loop over each combination of parameters
for model in filtered_models:
    for thread in threads:
        for processor_count in processors:
            # Construct the command to run
            cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
            # Run the command and get the output
            process = subprocess.Popen(
                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            )

            output = ""
            while process.poll() is None:
                output += process.stdout.read().decode()

            # Parse the output
            load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
            load_time = float(load_time_match.group(1)) if load_time_match else None

            metal_device = extract_device(output)
            sample_time, sample_runs = extract_metrics(output, "sample time")
            encode_time, encode_runs = extract_metrics(output, "encode time")
            decode_time, decode_runs = extract_metrics(output, "decode time")

            total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
            total_time = float(total_time_match.group(1)) if total_time_match else None

            model_name = model.replace("ggml-", "").replace(".bin", "")

            print(
                f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
            )
            # Store the times in the results dictionary
            results[(model_name, thread, processor_count)] = {
                loadTimeHeader: load_time,
                sampleTimeHeader: sample_time,
                encodeTimeHeader: encode_time,
                decodeTimeHeader: decode_time,
                sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
                encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
                decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
                totalTimeHeader: total_time,
            }

# Write the results to a CSV file
with open("benchmark_results.csv", "w", newline="") as csvfile:
    fieldnames = [
        gitHashHeader,
        modelHeader,
        hardwareHeader,
        recordingLengthHeader,
        threadHeader,
        processorCountHeader,
        loadTimeHeader,
        sampleTimeHeader,
        encodeTimeHeader,
        decodeTimeHeader,
        sampleTimePerRunHeader,
        encodeTimePerRunHeader,
        decodeTimePerRunHeader,
        totalTimeHeader,
    ]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    writer.writeheader()

    shortHash = get_git_short_hash()
    # Sort the results by total time in ascending order
    sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
    for params, times in sorted_results:
        row = {
            gitHashHeader: shortHash,
            modelHeader: params[0],
            hardwareHeader: metal_device,
            recordingLengthHeader: recording_length,
            threadHeader: params[1],
            processorCountHeader: params[2],
        }
        row.update(times)
        writer.writerow(row)
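As a usage note, here is a minimal sketch of consuming the benchmark_results.csv that the script above writes; it assumes a file produced by a previous run and uses the same column names defined in the script, and is not part of the commit.

```python
# Minimal sketch: read benchmark_results.csv produced by extra/bench.py and
# print a one-line summary per run. Assumes the file exists from a prior run.
import csv

with open("benchmark_results.csv", newline="") as csvfile:
    for row in csv.DictReader(csvfile):
        print(
            f"{row['Model']} (t={row['Thread']}, p={row['Processor Count']}): "
            f"{row['Total Time (ms)']} ms total on {row['Hardware']}"
        )
```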
