Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/vulkan #170

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]

[package]
name = "whisper-rs"
version = "0.12.1"
version = "0.13.0"
edition = "2021"
description = "Rust bindings for whisper.cpp"
license = "Unlicense"
Expand All @@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
whisper-rs-sys = { path = "sys", version = "0.10.1" }
whisper-rs-sys = { path = "sys", version = "0.11.0" }
log = { version = "0.4", optional = true }
tracing = { version = "0.1", optional = true }

Expand All @@ -23,19 +23,20 @@ hound = "3.5.0"
rand = "0.8.4"

[features]
default = []
default = ["openmp"]

raw-api = []
coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda", "_gpu"]
hipblas = ["whisper-rs-sys/hipblas", "_gpu"]
opencl = ["whisper-rs-sys/opencl"]
openblas = ["whisper-rs-sys/openblas"]
metal = ["whisper-rs-sys/metal", "_gpu"]
vulkan = ["whisper-rs-sys/vulkan", "_gpu"]
_gpu = []
test-with-tiny-model = []
whisper-cpp-log = ["dep:log"]
whisper-cpp-tracing = ["dep:tracing"]
openmp = ["whisper-rs-sys/openmp"]

[package.metadata.docs.rs]
features = ["simd"]
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ All disabled by default unless otherwise specified.
as whisper-rs-sys may be upgraded to a breaking version in a patch release of whisper-rs.
* `cuda`: enable CUDA support. Implicitly enables hidden GPU flag at runtime.
* `hipblas`: enable ROCm/hipBLAS support. Only available on linux. Implicitly enables hidden GPU flag at runtime.
* `opencl`: enable OpenCL support. Upstream whisper.cpp does not treat OpenCL as a GPU, so it is always enabled at
runtime.
* `openblas`: enable OpenBLAS support.
* `metal`: enable Metal support. Implicitly enables hidden GPU flag at runtime.
* `vulkan`: enable Vulkan support. Implicitly enables hidden GPU flag at runtime.
* `whisper-cpp-log`: allows hooking into whisper.cpp's log output and sending it to the `log` backend. Requires calling
* `whisper-cpp-tracing`: allows hooking into whisper.cpp's log output and sending it to the `tracing` backend.

Expand Down
2 changes: 0 additions & 2 deletions src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ pub struct SystemInfo {
pub fma: bool,
pub f16c: bool,
pub blas: bool,
pub clblast: bool,
pub cuda: bool,
}

Expand All @@ -123,7 +122,6 @@ impl Default for SystemInfo {
fma: whisper_rs_sys::ggml_cpu_has_fma() != 0,
f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0,
blas: whisper_rs_sys::ggml_cpu_has_blas() != 0,
clblast: whisper_rs_sys::ggml_cpu_has_clblast() != 0,
cuda: whisper_rs_sys::ggml_cpu_has_cuda() != 0,
}
}
Expand Down
3 changes: 0 additions & 3 deletions src/whisper_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,6 @@ unsafe impl Sync for WhisperInnerContext {}

pub struct WhisperContextParameters<'a> {
/// Use GPU if available.
///
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
/// (in that case, GPU is always enabled).
pub use_gpu: bool,
/// Enable flash attention, default false
///
Expand Down
11 changes: 0 additions & 11 deletions src/whisper_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,6 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.max_tokens = max_tokens;
}

/// # EXPERIMENTAL
///
/// Speed up audio ~2x by using phase vocoder.
/// Note that this can significantly reduce the accuracy of the transcription.
///
/// Defaults to false.
pub fn set_speed_up(&mut self, speed_up: bool) {
self.fp.speed_up = speed_up;
}

/// # EXPERIMENTAL
///
/// Enables debug mode, such as dumping the log mel spectrogram.
Expand All @@ -242,7 +232,6 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// # EXPERIMENTAL
///
/// Overwrite the audio context size. 0 = default.
/// As with [set_speed_up](FullParams::set_speed_up), this can significantly reduce the accuracy of the transcription.
///
/// Defaults to 0.
pub fn set_audio_ctx(&mut self, audio_ctx: c_int) {
Expand Down
39 changes: 0 additions & 39 deletions src/whisper_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,45 +64,6 @@ impl WhisperState {
}
}

/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// Applies a Phase Vocoder to speed up the audio x2.
/// The resulting spectrogram is stored in the context transparently.
///
/// # Arguments
/// * pcm: The raw PCM audio.
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel_phase_vocoder(
&mut self,
pcm: &[f32],
threads: usize,
) -> Result<(), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
self.ctx.ctx,
self.ptr,
pcm.as_ptr(),
pcm.len() as c_int,
threads as c_int,
)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 0 {
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}

/// This can be used to set a custom log mel spectrogram inside the provided whisper state.
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
///
Expand Down
46 changes: 23 additions & 23 deletions sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "whisper-rs-sys"
version = "0.10.1"
version = "0.11.0"
edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense"
Expand All @@ -10,29 +10,28 @@ links = "whisper"
include = [
"whisper.cpp/bindings/javascript/package-tmpl.json",
"whisper.cpp/bindings/CMakeLists.txt",
"whisper.cpp/cmake",
"whisper.cpp/coreml",
"whisper.cpp/CMakeLists.txt",
"whisper.cpp/ggml.c",
"whisper.cpp/ggml.h",
"whisper.cpp/ggml-alloc.c",
"whisper.cpp/ggml-alloc.h",
"whisper.cpp/ggml-backend.c",
"whisper.cpp/ggml-backend.h",
"whisper.cpp/ggml-backend-impl.h",
"whisper.cpp/ggml-cuda.cu",
"whisper.cpp/ggml-cuda.h",
"whisper.cpp/ggml-impl.h",
"whisper.cpp/ggml-metal.h",
"whisper.cpp/ggml-metal.m",
"whisper.cpp/ggml-metal.metal",
"whisper.cpp/ggml-opencl.cpp",
"whisper.cpp/ggml-opencl.h",
"whisper.cpp/ggml-quants.h",
"whisper.cpp/ggml-quants.c",
"whisper.cpp/cmake",
"whisper.cpp/src/coreml",
"whisper.cpp/src/CMakeLists.txt",
"whisper.cpp/src/whisper.cpp",
"whisper.cpp/include/whisper.h",
"whisper.cpp/ggml/src/ggml.c",
"whisper.cpp/ggml/src/ggml-alloc.c",
"whisper.cpp/ggml/src/ggml-backend.c",
"whisper.cpp/ggml/src/ggml-cuda.cu",
"whisper.cpp/ggml/src/ggml-impl.h",
"whisper.cpp/ggml/src/ggml-metal.m",
"whisper.cpp/ggml/src/ggml-metal.metal",
"whisper.cpp/ggml/src/ggml-quants.h",
"whisper.cpp/ggml/src/ggml-quants.c",
"whisper.cpp/ggml/include/ggml.h",
"whisper.cpp/ggml/include/ggml-alloc.h",
"whisper.cpp/ggml/include/ggml-backend.h",
"whisper.cpp/ggml/include/ggml-backend-impl.h",
"whisper.cpp/ggml/include/ggml-cuda.h",
"whisper.cpp/ggml/include/ggml-metal.h",
"whisper.cpp/LICENSE",
"whisper.cpp/whisper.cpp",
"whisper.cpp/whisper.h",
"src/*.rs",
"build.rs",
"wrapper.h",
Expand All @@ -44,10 +43,11 @@ include = [
coreml = []
cuda = []
hipblas = []
opencl = []
openblas = []
metal = []
vulkan = []
force-debug = []
openmp = []

[build-dependencies]
cmake = "0.1"
Expand Down
76 changes: 51 additions & 25 deletions sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ fn main() {

#[cfg(feature = "coreml")]
println!("cargo:rustc-link-lib=static=whisper.coreml");
#[cfg(feature = "opencl")]
{
println!("cargo:rustc-link-lib=clblast");
println!("cargo:rustc-link-lib=OpenCL");
}
#[cfg(feature = "openblas")]
{
println!("cargo:rustc-link-lib=openblas");
Expand Down Expand Up @@ -81,6 +76,13 @@ fn main() {
}
}

#[cfg(feature = "openmp")]
{
if target.contains("gnu") {
println!("cargo:rustc-link-lib=gomp");
}
}

println!("cargo:rerun-if-changed=wrapper.h");

let out = PathBuf::from(env::var("OUT_DIR").unwrap());
Expand All @@ -104,10 +106,11 @@ fn main() {
let bindings = bindgen::Builder::default().header("wrapper.h");

#[cfg(feature = "metal")]
let bindings = bindings.header("whisper.cpp/ggml-metal.h");
let bindings = bindings.header("whisper.cpp/ggml/include/ggml-metal.h");

let bindings = bindings
.clang_arg("-I./whisper.cpp")
.clang_arg("-I./whisper.cpp/include")
.clang_arg("-I./whisper.cpp/ggml/include")
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate();

Expand Down Expand Up @@ -150,11 +153,11 @@ fn main() {
}

if cfg!(feature = "cuda") {
config.define("WHISPER_CUDA", "ON");
config.define("GGML_CUDA", "ON");
}

if cfg!(feature = "hipblas") {
config.define("WHISPER_HIPBLAS", "ON");
config.define("GGML_HIPBLAS", "ON");
config.define("CMAKE_C_COMPILER", "hipcc");
config.define("CMAKE_CXX_COMPILER", "hipcc");
println!("cargo:rerun-if-env-changed=AMDGPU_TARGETS");
Expand All @@ -163,21 +166,35 @@ fn main() {
}
}

if cfg!(feature = "openblas") {
config.define("WHISPER_OPENBLAS", "ON");
if cfg!(feature = "vulkan") {
config.define("GGML_VULKAN", "ON");
if cfg!(windows) {
println!("cargo:rerun-if-env-changed=VULKAN_SDK");
println!("cargo:rustc-link-lib=vulkan-1");
let vulkan_path = match env::var("VULKAN_SDK") {
Ok(path) => PathBuf::from(path),
Err(_) => panic!(
"Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set"
),
};
let vulkan_lib_path = vulkan_path.join("Lib");
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
} else {
println!("cargo:rustc-link-lib=vulkan");
}
}

if cfg!(feature = "opencl") {
config.define("WHISPER_CLBLAST", "ON");
if cfg!(feature = "openblas") {
config.define("GGML_BLAS", "ON");
}

if cfg!(feature = "metal") {
config.define("WHISPER_METAL", "ON");
config.define("WHISPER_METAL_NDEBUG", "ON");
config.define("WHISPER_METAL_EMBED_LIBRARY", "ON");
config.define("GGML_METAL", "ON");
config.define("GGML_METAL_NDEBUG", "ON");
config.define("GGML_METAL_EMBED_LIBRARY", "ON");
} else {
// Metal is enabled by default, so we need to explicitly disable it
config.define("WHISPER_METAL", "OFF");
config.define("GGML_METAL", "OFF");
}

if cfg!(debug_assertions) || cfg!(feature = "force-debug") {
Expand All @@ -197,18 +214,17 @@ fn main() {
}
}

if cfg!(not(feature = "openmp")) {
config.define("GGML_OPENMP", "OFF");
}

let destination = config.build();

if target.contains("window") && !target.contains("gnu") {
println!(
"cargo:rustc-link-search={}",
out.join("build").join("Release").display()
);
} else {
println!("cargo:rustc-link-search={}", out.join("build").display());
}
add_link_search_path(&out.join("lib")).unwrap();

println!("cargo:rustc-link-search=native={}", destination.display());
println!("cargo:rustc-link-lib=static=whisper");
println!("cargo:rustc-link-lib=static=ggml");

// for whatever reason this file is generated during build and triggers cargo complaining
_ = std::fs::remove_file("bindings/javascript/package.json");
Expand All @@ -226,3 +242,13 @@ fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
Some("stdc++")
}
}

fn add_link_search_path(dir: &std::path::Path) -> std::io::Result<()> {
if dir.is_dir() {
println!("cargo:rustc-link-search={}", dir.display());
for entry in std::fs::read_dir(dir)? {
add_link_search_path(&entry?.path())?;
}
}
Ok(())
}
Loading
Loading