
Significant performance penalty w.r.t. whisper.cpp #74

Closed · wdoppenberg opened this issue Jul 18, 2023 · 10 comments

Comments

@wdoppenberg commented Jul 18, 2023

First of all, thank you for your work creating a safe wrapper around whisper.cpp.

As mentioned in #73, the performance of whisper-rs is quite poor compared to the reference implementation. I'll attempt to demonstrate below.

Setup

I'm using an M2 Max MacBook Pro with 64GB of (shared) memory. My goal is to run a web server with CoreML enabled, but if necessary I can run the tests CPU-only later. I'll attach output generated by flamegraph.

Rust script

// src/main.rs
mod utils;

use clap::Parser;
use anyhow::Result;
use log::info;

use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
use crate::utils::{decode, forward_pass, load_audio};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct WhisperCli {
    #[arg(short, long, default_value = "sample_data/jfk.wav")]
    input_file: String,
    #[arg(short, long, default_value = "models/ggml-medium.bin")]
    model_path: String,
}


fn main() -> Result<()> {
    env_logger::init();
    let args = WhisperCli::parse();

    let mut params = FullParams::new(
        SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1}
    );

    params.set_n_threads(8);
    params.set_translate(false);


    info!("Loading audio file {}", args.input_file);
    let raw_audio = load_audio(&args.input_file)?;

    info!("Decoding audio file");
    let decoded_audio = decode(raw_audio)?;

    info!("Loading model from {}", args.model_path);
    let whisper_context = WhisperContext::new(&args.model_path)?;
    let mut whisper_state = whisper_context.create_state()?;

    info!("Running forward pass");
    let result = forward_pass(&decoded_audio, params, &mut whisper_state);

    println!("{}", result);
    Ok(())
}
// src/utils.rs
use anyhow::Result;
use std::fs::File;
use std::io::Read;

use std::io::Cursor;
use log::{debug, trace};
use rodio::{Decoder, Source};
use rodio::source::UniformSourceIterator;
use whisper_rs::{FullParams, WhisperState};

const SAMPLE_RATE: u32 = 16000;
const CHANNELS: u16 = 1;

const LOW_PASS: u32 = 3000;
const HIGH_PASS: u32 = 200;

/// Load an audio file from the given path
/// and return as a vector of u8
///
/// # Arguments
/// * `path` - The path to the audio file
pub fn load_audio(path: &str) -> Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer)?;

    Ok(buffer)
}


/// Decode the audio file and return as a vector of f32
pub fn decode(bytes: Vec<u8>) -> Result<Vec<f32>> {
    // Decode the audio file
    let input = Cursor::new(bytes);
    let source = Decoder::new(input)?;

    // Resample to output sample rate and channels
    let resample = UniformSourceIterator::new(
        source, CHANNELS, SAMPLE_RATE,
    );
    // High and low pass filters to enhance the audio
    let pass_filter = resample
        .low_pass(LOW_PASS)
        .high_pass(HIGH_PASS)
        .convert_samples();

    Ok(whisper_rs::convert_integer_to_float_audio(&pass_filter.collect::<Vec<i16>>()))
}


pub fn forward_pass(decoded_audio: &[f32], params: FullParams, whisper_state: &mut WhisperState) -> String {
    debug!("Starting forward pass");
    let _ = whisper_state
        .full(params, decoded_audio)
        .expect("Failed to run model");
    let mut result = String::new();

    debug!("Decoding results");
    // fetch the results
    let num_segments = whisper_state
        .full_n_segments()
        .expect("failed to get number of segments");

    for i in 0..num_segments {
        let segment = whisper_state
            .full_get_segment_text(i)
            .expect("failed to get segment");
        let start_timestamp = whisper_state
            .full_get_segment_t0(i)
            .expect("failed to get segment start timestamp");
        let end_timestamp = whisper_state
            .full_get_segment_t1(i)
            .expect("failed to get segment end timestamp");
        trace!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
        result.push_str(&format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment));
    }
    result
}
# Cargo.toml
[package]
name = "whisper-rs-cli"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.71"
whisper-rs = { version = "0.8.0", features = ["coreml"] }
clap = { version = "4.3.12", features = ["derive"]}
env_logger = "0.10.0"
log = "0.4.19"
rodio = "0.17.1"


[[bin]]
name = "whisper-rs-cli"
path = "src/main.rs"

Data

For testing, I've converted a short JFK speech to WAV. See this link. Converting to WAV is done using ffmpeg:

ffmpeg -i main.mp4 -acodec pcm_s16le -ac 1 -ar 16000 jfk.wav

Results

I won't average multiple iterations since the differences are quite clear. Furthermore, I've run both scripts beforehand to ensure that the CoreML model is properly compiled for my architecture, so I can confirm that model loading is not the issue. The chosen model is ggml-medium.bin. Assuming you have compiled both whisper.cpp and whisper-rs and are in the root of each repository, the commands used are as follows:

sudo time flamegraph -- ./main -m models/ggml-medium.bin -t 8  -f jfk.wav
sudo time flamegraph -- target/release/whisper-rs-cli

CoreML enabled

whisper-rs

85.27 real       537.47 user         6.41 sys

(Rust flamegraph attached)

whisper.cpp

12.30 real        42.14 user        15.36 sys

(whisper.cpp flamegraph attached)

Please let me know what you think and where I can help out. Admittedly I'm a bit inexperienced with Rust but I'd love to learn, especially solving such an issue.

@randomairborne (Contributor)

I was able to reproduce this, and am working on it with @tazz4843

@tazz4843 (Owner)

At this point, the only difference I can point to is that whisper-rs uses v1.4.2 of whisper.cpp by default, while you're likely using upstream git master for these tests. Upgrading whisper-rs to git master seems to speed it up somewhat, but I don't have Apple Silicon to test on, so I can't do much poking around myself.
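
If anyone wants to test that, something along these lines in Cargo.toml should build against whisper-rs git master instead of the crates.io release (a sketch, untested on my end; the branch name is an assumption):

# Cargo.toml (hypothetical): pull whisper-rs from git instead of crates.io
[dependencies]
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs", branch = "master", features = ["coreml"] }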

@jbrough (Contributor) commented Jul 22, 2023

My test isn't the same, but it runs 2 seconds slower after updating whisper-rs to use whisper.cpp master.

@wdoppenberg (Author) commented Jul 26, 2023

I've also pulled the latest whisper.cpp. To get it to compile, I had to edit two function signatures in whisper-rs, adding *mut whisper_context as an argument to the following:

/// Get the ID of the translate task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_translate(struct whisper_context * ctx)`
pub fn token_translate(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_translate(ctx) }
}

/// Get the ID of the transcribe task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_transcribe(struct whisper_context * ctx)`
pub fn token_transcribe(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_transcribe(ctx) }
}

Only a slight improvement:

77.39 real       354.37 user       118.68 sys

@chen-rn commented Jul 31, 2023

@wdoppenberg do you have any guesses as to why there's such a big performance discrepancy?

@tazz4843 (Owner)

Given I don't have Apple Silicon to test on, I can't do much to help with this besides suggesting x86 instead. Hopefully someone else can figure it out.

@jbrough (Contributor) commented Aug 1, 2023

I have M1 and M2 and can try it out if someone provides a repo with a test case that can be run.

Right now, for the canonical "your country" JFK test case, whisper-rs and whisper.cpp are the same.

I put a timer inside whisper.cpp's whisper_full_with_state: it took 1612.61 ms.

I put a timer around state.full -> state.full_get_segment_text in whisper-rs: it took 1648 ms.

C code changes:

     struct whisper_full_params   params,
                    const float * samples,
                            int   n_samples) {
+    struct timeval start_time, end_time;
+    gettimeofday(&start_time, NULL);
+
     // clear old results
     auto & result_all = state->result_all;

     result_all.clear();
@@ -4761,7 +4768,16 @@ int whisper_full_with_state(
         }
     }

+    gettimeofday(&end_time, NULL);
+
+    // Calculate the elapsed time in milliseconds
+    double elapsed_ms = (end_time.tv_sec - start_time.tv_sec) * 1000.0 +
+                        (end_time.tv_usec - start_time.tv_usec) / 1000.0;
+
+    // Print the result to stdout
+    printf("Elapsed time: %.2f ms\n", elapsed_ms);
+
     return 0;
 }
❯ time ./main -m models/ggml-medium.en.bin -t 8  -f jfk.wav
whisper_init_from_file_no_state: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab       = 51864
whisper_model_load: n_audio_ctx   = 1500
whisper_model_load: n_audio_state = 1024
whisper_model_load: n_audio_head  = 16
whisper_model_load: n_audio_layer = 24
whisper_model_load: n_text_ctx    = 448
whisper_model_load: n_text_state  = 1024
whisper_model_load: n_text_head   = 16
whisper_model_load: n_text_layer  = 24
whisper_model_load: n_mels        = 80
whisper_model_load: ftype         = 1
whisper_model_load: qntvr         = 0
whisper_model_load: type          = 4
whisper_model_load: mem required  = 1899.00 MB (+   43.00 MB per decoder)
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx     = 1462.58 MB
whisper_model_load: model size    = 1462.12 MB
whisper_init_state: kv self size  =   42.00 MB
whisper_init_state: kv cross size =  140.62 MB
whisper_init_state: loading Core ML model from 'models/ggml-medium.en-encoder.mlmodelc'
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded

system_info: n_threads = 8 / 8 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 | OPENVINO = 0 |

main: processing 'jfk.wav' (176000 samples, 11.0 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...


[00:00:00.000 --> 00:00:11.000]   And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
Elapsed time: 1612.61 ms

whisper-rs:

whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
[src/main.rs:66] idx = 1041
1041 [00:00:10.410]  And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
[src/main.rs:98] elapsed = 1648
 65             if let Some(frames) = buf.add(idx, mel) {
 66                 dbg!(idx);
 67                 let path = format!("{}/frame_{}.tga", mel_path, idx);
 68                 let _ = save_tga_8bit(&frames, n_mels, &path);
 69
 70                 let ms = duration_ms_for_n_frames(hop_size, sampling_rate, idx);
 71                 let time = format_milliseconds(ms as u64);
 72
 73                 let start = std::time::Instant::now();
 74                 let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
 75                 params.set_n_threads(6);
 76                 params.set_single_segment(true);
 77                 params.set_language(Some("en"));
 78                 params.set_print_special(false);
 79                 params.set_print_progress(false);
 80                 params.set_print_realtime(false);
 81                 params.set_print_timestamps(false);
 82                 state.set_mel(&frames).unwrap();
 83
 84                 let empty = vec![];
 85                 state.full(params, &empty[..]).unwrap();
 86
 87                 let num_segments = state.full_n_segments().unwrap();
 88                 if num_segments > 0 {
 89                     if let Ok(text) = state.full_get_segment_text(0) {
 90                         let msg = format!("{} [{}] {}", idx, time, text);
 91                         println!("{}", msg);
 92                     } else {
 93                         println!("Error retrieving text for segment.");
 94                     }
 95                 }
 96
 97                 let elapsed = start.elapsed().as_millis();
 98                 dbg!(elapsed);
 99             }

I'm using the set_mel API, but there's no reason to think PCM audio would be slower in Rust than in C. If someone can provide a test case that can be run easily, I can try it out.
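
Something as small as this would work as a shared test case (a sketch against the whisper-rs 0.8 API as used above; the paths and the hound crate for WAV reading are assumptions, and the input must already be 16 kHz mono s16le):

// bench.rs: minimal timing case for whisper-rs (paths are placeholders)
use std::time::Instant;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};

fn main() {
    let ctx = WhisperContext::new("models/ggml-medium.bin").expect("failed to load model");
    let mut state = ctx.create_state().expect("failed to create state");

    // Read raw 16-bit PCM samples and convert to f32, as whisper expects.
    let samples: Vec<i16> = hound::WavReader::open("jfk.wav")
        .expect("failed to open wav")
        .into_samples::<i16>()
        .map(|s| s.expect("invalid sample"))
        .collect();
    let audio = whisper_rs::convert_integer_to_float_audio(&samples);

    // Time only the inference, not model loading or CoreML compilation.
    let params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
    let start = Instant::now();
    state.full(params, &audio).expect("failed to run model");
    println!("whisper_full took {} ms", start.elapsed().as_millis());
}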

@wdoppenberg (Author) commented Aug 1, 2023

Still haven't been able to find the issue. When examining the call tree, I find that, as one might expect, almost all time is spent in ggml_compute_forward_mul_mat when running on a single thread.

(screenshot of the call tree attached)

It almost feels like the whisper.cpp library was compiled without optimization flags, which of course is not the case.

@jbrough (Contributor) commented Aug 1, 2023

Think it's your machine?

@wdoppenberg (Author)
I have found the issue:

In my script, I used SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 }. This is quite heavy: beam search keeps beam_size candidate decodings alive, so the decoder does many times the work of greedy sampling.

Using SamplingStrategy::Greedy { best_of: 0 }, it runs in roughly 10.11 seconds (with a synced whisper.cpp submodule).
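
For anyone hitting the same thing, the fix in the script above is a one-line change (a sketch; greedy sampling is, as far as I can tell, what the whisper.cpp CLI uses by default, which is why the original comparison was lopsided):

// In main() above, replace the BeamSearch params with:
let mut params = FullParams::new(
    SamplingStrategy::Greedy { best_of: 0 }
);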
