diff --git a/CHANGELOG.md b/CHANGELOG.md index edbf963..264fb6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Telegram voice and audio message handling with automatic file download (#524) - STT bootstrap wiring: `WhisperProvider` created from `[llm.stt]` config behind `stt` feature (#529) - Slack audio file upload handling with host validation and size limits (#525) +- Local Whisper backend via candle for offline STT with symphonia audio decode and rubato resampling (#523) - Shell-based installation script (`install/install.sh`) with SHA256 verification, platform detection, and `--version` flag - Shellcheck lint job in CI pipeline - Per-job permission scoping in release workflow (least privilege) diff --git a/Cargo.lock b/Cargo.lock index 45467f7..aa3118b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1842,6 +1842,12 @@ dependencies = [ "regex", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -4538,6 +4544,15 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -5056,6 +5071,15 @@ dependencies = [ "erasable", ] +[[package]] +name = "realfft" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677" +dependencies = [ + "rustfft", +] + [[package]] name = "reborrow" version = "0.5.5" @@ -5310,6 +5334,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rubato" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5258099699851cfd0082aeb645feb9c084d9a5e1f1b8d5372086b989fc5e56a1" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rust-embed" version = "8.11.0" @@ -5365,6 +5401,20 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustix" version = "1.1.3" @@ -6289,6 +6339,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "string_cache" version = "0.9.0" @@ -6381,6 +6437,115 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "symphonia" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5773a4c030a19d9bfaa090f49746ff35c75dfddfa700df7a5939d5e076a57039" +dependencies = [ + "lazy_static", + "symphonia-bundle-flac", + "symphonia-bundle-mp3", + "symphonia-codec-pcm", + "symphonia-core", + "symphonia-format-ogg", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-bundle-flac" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91565e180aea25d9b80a910c546802526ffd0072d0b8974e3ebe59b686c9976" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-bundle-mp3" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4872dd6bb56bf5eac799e3e957aa1981086c3e613b27e0ac23b176054f7c57ed" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e89d716c01541ad3ebe7c91ce4c8d38a7cf266a3f7b2f090b108fb0cb031d95" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-core" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea00cc4f79b7f6bb7ff87eddc065a1066f3a43fe1875979056672c9ef948c2af" +dependencies = [ + "arrayvec", + "bitflags 1.3.2", + "bytemuck", + "lazy_static", + "log", +] + +[[package]] +name = "symphonia-format-ogg" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b4955c67c1ed3aa8ae8428d04ca8397fbef6a19b2b051e73b5da8b1435639cb" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2d7c3df0e7d94efb68401d81906eae73c02b40d5ec1a141962c592d0f11a96f" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-metadata" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36306ff42b9ffe6e5afc99d49e121e0bd62fe79b9db7b9681d48e29fa19e6b16" +dependencies = [ + "encoding_rs", + "lazy_static", + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-utils-xiph" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27c85ab799a338446b68eec77abf42e1a6f1bb490656e121c6e27bfbab9f16" +dependencies = [ + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "syn" version = "1.0.109" @@ -7166,6 +7331,16 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "tree-sitter" version = "0.26.5" @@ -8775,9 +8950,11 @@ dependencies = [ "ollama-rs", "proptest", "reqwest 0.13.2", + "rubato", "schemars 1.2.1", "serde", "serde_json", + "symphonia", "thiserror 2.0.18", "tokenizers", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 1b58928..0976621 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,11 +43,13 @@ reqwest = { version = "0.13", default-features = false } rmcp = "0.15" scrape-core = "0.2.2" subtle = "2.6" +rubato = "0.16" schemars = "1.2" similar = "2.7" serde = "1.0" serde_json = "1.0" serial_test = "3.3" +symphonia = { version = "0.5.5", default-features = false, features = ["mp3", "ogg", "wav", "flac", "pcm"] } sqlx = { version = "0.8", default-features = false, features = ["macros"] } teloxide = { version = "0.17", default-features = false, features = ["rustls", "ctrlc_handler", "macros"] } tempfile = "3" diff --git a/README.md b/README.md index 6bffadd..21ba498 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ Skills **evolve**: failure detection triggers self-reflection, and the agent gen |----------|-------------| | **MCP** | Connect external tool servers (stdio + HTTP) with SSRF protection | | **A2A** | Agent-to-agent communication via JSON-RPC 2.0 with SSE streaming | -| **Audio input** | Speech-to-text via OpenAI Whisper (25 MB limit); Telegram and Slack audio files transcribed automatically | +| **Audio input** | Speech-to-text via OpenAI Whisper API or local Candle Whisper (offline, feature-gated); Telegram and Slack audio files transcribed automatically | | **Channels** | CLI, Telegram (text + voice), Discord, Slack, TUI — all with streaming support | | **Gateway** | HTTP webhook ingestion with bearer auth and rate limiting | | **Native tool_use** | Structured tool calling via Claude/OpenAI APIs; text fallback for local models | diff --git a/crates/zeph-llm/Cargo.toml b/crates/zeph-llm/Cargo.toml index 425f594..e469faf 100644 --- a/crates/zeph-llm/Cargo.toml +++ b/crates/zeph-llm/Cargo.toml @@ -10,7 +10,7 @@ repository.workspace = true default = [] mock = [] stt = ["reqwest/multipart"] -candle = ["dep:candle-core", "dep:candle-nn", "dep:candle-transformers", "dep:hf-hub", "dep:tokenizers"] +candle = ["dep:candle-core", "dep:candle-nn", "dep:candle-transformers", "dep:hf-hub", "dep:tokenizers", "dep:symphonia", "dep:rubato"] cuda = ["candle", "candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] metal = ["candle", "candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] @@ -28,6 +28,8 @@ schemars.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true tokenizers = { workspace = true, optional = true } +rubato = { workspace = true, optional = true } +symphonia = { workspace = true, optional = true } tokio = { workspace = true, features = ["rt", "sync", "time"] } tokio-stream.workspace = true tracing.workspace = true diff --git a/crates/zeph-llm/README.md b/crates/zeph-llm/README.md index 6892c37..0cf2404 100644 --- a/crates/zeph-llm/README.md +++ b/crates/zeph-llm/README.md @@ -19,6 +19,7 @@ Defines the `LlmProvider` trait and ships concrete backends for Ollama, Claude, | `orchestrator` | Multi-model coordination and fallback | | `router` | Model selection and routing logic | | `stt` | `SpeechToText` trait and `WhisperProvider` (OpenAI Whisper, feature-gated behind `stt`) | +| `candle_whisper` | Local offline STT via Candle (whisper-tiny/base/small, feature-gated behind `candle`) | | `error` | `LlmError` — unified error type | **Re-exports:** `LlmProvider`, `LlmError` diff --git a/crates/zeph-llm/src/candle_whisper.rs b/crates/zeph-llm/src/candle_whisper.rs new file mode 100644 index 0000000..b0f7d1f --- /dev/null +++ b/crates/zeph-llm/src/candle_whisper.rs @@ -0,0 +1,379 @@ +use std::future::Future; +use std::io::Cursor; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + +use candle_core::{Device, IndexOp, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::whisper::{self as m, Config}; +use tokenizers::Tokenizer; + +use crate::error::LlmError; +use crate::stt::{SpeechToText, Transcription}; + +#[derive(Clone)] +pub struct CandleWhisperProvider { + model: Arc>, + config: Config, + mel_filters: Vec, + tokenizer: Arc, + device: Device, +} + +impl std::fmt::Debug for CandleWhisperProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CandleWhisperProvider") + .field("device", &device_name(&self.device)) + .finish_non_exhaustive() + } +} + +fn device_name(d: &Device) -> &'static str { + match d { + Device::Cpu => "cpu", + Device::Cuda(_) => "cuda", + Device::Metal(_) => "metal", + } +} + +fn detect_device() -> Device { + #[cfg(feature = "metal")] + { + if let Ok(d) = Device::new_metal(0) { + return d; + } + } + #[cfg(feature = "cuda")] + { + if let Ok(d) = Device::new_cuda(0) { + return d; + } + } + Device::Cpu +} + +impl CandleWhisperProvider { + /// Load a Whisper model from a HuggingFace repo. + /// + /// # Errors + /// + /// Returns `LlmError::ModelLoad` if downloading or loading fails. + pub fn load(repo_id: &str, device: Option) -> Result { + let device = device.unwrap_or_else(detect_device); + tracing::info!( + repo = repo_id, + device = device_name(&device), + "loading candle whisper model" + ); + + let api = hf_hub::api::sync::Api::new() + .map_err(|e| LlmError::ModelLoad(format!("hf-hub init: {e}")))?; + let repo = api.model(repo_id.to_string()); + + let config_path = repo + .get("config.json") + .map_err(|e| LlmError::ModelLoad(format!("config.json: {e}")))?; + let tokenizer_path = repo + .get("tokenizer.json") + .map_err(|e| LlmError::ModelLoad(format!("tokenizer.json: {e}")))?; + let weights_path = repo + .get("model.safetensors") + .map_err(|e| LlmError::ModelLoad(format!("model.safetensors: {e}")))?; + + let config: Config = serde_json::from_reader(std::io::BufReader::new( + std::fs::File::open(&config_path) + .map_err(|e| LlmError::ModelLoad(format!("open config: {e}")))?, + ))?; + + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| LlmError::ModelLoad(format!("tokenizer: {e}")))?; + + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device) + .map_err(|e| LlmError::ModelLoad(format!("weights: {e}")))? + }; + + let model = m::model::Whisper::load(&vb, config.clone())?; + + let mel_bytes = match config.num_mel_bins { + 80 => include_bytes!("melfilters.bytes").as_slice(), + 128 => include_bytes!("melfilters128.bytes").as_slice(), + n => { + return Err(LlmError::ModelLoad(format!( + "unsupported num_mel_bins: {n}" + ))); + } + }; + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + for (i, chunk) in mel_bytes.chunks_exact(4).enumerate() { + mel_filters[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + } + + tracing::info!("candle whisper model loaded"); + + Ok(Self { + model: Arc::new(Mutex::new(model)), + config, + mel_filters, + tokenizer: Arc::new(tokenizer), + device, + }) + } + + fn transcribe_sync(&self, audio: &[u8]) -> Result { + let pcm = decode_audio(audio)?; + let mel = m::audio::pcm_to_mel(&self.config, &pcm, &self.mel_filters); + let mel_len = mel.len(); + let n_mel = self.config.num_mel_bins; + + let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &self.device)?; + + let sot = self + .tokenizer + .token_to_id(m::SOT_TOKEN) + .ok_or_else(|| LlmError::TranscriptionFailed("missing SOT token".into()))?; + let transcribe = self + .tokenizer + .token_to_id(m::TRANSCRIBE_TOKEN) + .ok_or_else(|| LlmError::TranscriptionFailed("missing TRANSCRIBE token".into()))?; + let no_timestamps = self + .tokenizer + .token_to_id(m::NO_TIMESTAMPS_TOKEN) + .ok_or_else(|| LlmError::TranscriptionFailed("missing NO_TIMESTAMPS token".into()))?; + let eot = self + .tokenizer + .token_to_id(m::EOT_TOKEN) + .ok_or_else(|| LlmError::TranscriptionFailed("missing EOT token".into()))?; + + let language_token = self.tokenizer.token_to_id("<|en|>").ok_or_else(|| { + LlmError::TranscriptionFailed("language token not found in tokenizer".into()) + })?; + + let mut model = self + .model + .lock() + .map_err(|e| LlmError::TranscriptionFailed(format!("lock: {e}")))?; + model.reset_kv_cache(); + + let audio_features = model.encoder.forward(&mel, true)?; + + const MAX_DECODE_TOKENS: usize = 224; + + let mut tokens = vec![sot, language_token, transcribe, no_timestamps]; + + for _ in 0..MAX_DECODE_TOKENS { + let token_tensor = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let logits = + model + .decoder + .forward(&token_tensor, &audio_features, tokens.len() == 4)?; + + let (_, seq_len, _) = logits.dims3()?; + let next_logits = logits.i((0, seq_len - 1))?; + let next_token = next_logits + .argmax(candle_core::D::Minus1)? + .to_scalar::()?; + + if next_token == eot { + break; + } + tokens.push(next_token); + } + + // Decode only generated tokens (skip prompt tokens) + let generated = &tokens[4..]; + let text = self + .tokenizer + .decode(generated, true) + .map_err(|e| LlmError::TranscriptionFailed(format!("decode: {e}")))?; + + Ok(Transcription { + text: text.trim().to_string(), + language: Some("en".into()), + duration_secs: Some(pcm.len() as f32 / m::SAMPLE_RATE as f32), + }) + } +} + +impl SpeechToText for CandleWhisperProvider { + fn transcribe( + &self, + audio: &[u8], + _filename: Option<&str>, + ) -> Pin> + Send + '_>> { + let audio = audio.to_vec(); + Box::pin(async move { + let provider = self.clone(); + tokio::task::spawn_blocking(move || provider.transcribe_sync(&audio)) + .await + .map_err(|e| LlmError::TranscriptionFailed(e.to_string()))? + }) + } +} + +fn decode_audio(bytes: &[u8]) -> Result, LlmError> { + use symphonia::core::audio::SampleBuffer; + use symphonia::core::codecs::DecoderOptions; + use symphonia::core::formats::FormatOptions; + use symphonia::core::io::MediaSourceStream; + use symphonia::core::meta::MetadataOptions; + use symphonia::core::probe::Hint; + + let cursor = Cursor::new(bytes.to_vec()); + let mss = MediaSourceStream::new(Box::new(cursor), Default::default()); + + let probed = symphonia::default::get_probe() + .format( + &Hint::new(), + mss, + &FormatOptions::default(), + &MetadataOptions::default(), + ) + .map_err(|e| LlmError::TranscriptionFailed(format!("probe: {e}")))?; + + let mut format = probed.format; + let track = format + .default_track() + .ok_or_else(|| LlmError::TranscriptionFailed("no audio track".into()))?; + let sample_rate = track + .codec_params + .sample_rate + .ok_or_else(|| LlmError::TranscriptionFailed("unknown sample rate".into()))?; + let channels = track.codec_params.channels.map_or(1, |c| c.count()); + let track_id = track.id; + + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &DecoderOptions::default()) + .map_err(|e| LlmError::TranscriptionFailed(format!("decoder: {e}")))?; + + let mut pcm = Vec::new(); + + while let Ok(packet) = format.next_packet() { + if packet.track_id() != track_id { + continue; + } + let decoded = match decoder.decode(&packet) { + Ok(d) => d, + Err(e) => { + tracing::trace!("skipping packet decode error: {e}"); + continue; + } + }; + let spec = *decoded.spec(); + let mut sample_buf = SampleBuffer::::new(decoded.capacity() as u64, spec); + sample_buf.copy_interleaved_ref(decoded); + let samples = sample_buf.samples(); + + if channels > 1 { + for chunk in samples.chunks(channels) { + let avg = chunk.iter().sum::() / channels as f32; + pcm.push(avg); + } + } else { + pcm.extend_from_slice(samples); + } + } + + if pcm.is_empty() { + return Err(LlmError::TranscriptionFailed("no audio decoded".into())); + } + + // Guard against pathological inputs: max 5 minutes at the source sample rate + let max_samples = 5 * 60 * sample_rate as usize; + if pcm.len() > max_samples { + return Err(LlmError::TranscriptionFailed(format!( + "audio too long: {} samples exceeds {max_samples} limit (5 min)", + pcm.len() + ))); + } + + // Resample to 16kHz if needed + if sample_rate != m::SAMPLE_RATE as u32 { + pcm = resample(&pcm, sample_rate, m::SAMPLE_RATE as u32)?; + } + + Ok(pcm) +} + +fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Result, LlmError> { + use rubato::{ + Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction, + }; + + let params = SincInterpolationParameters { + sinc_len: 256, + f_cutoff: 0.95, + interpolation: SincInterpolationType::Linear, + oversampling_factor: 256, + window: WindowFunction::BlackmanHarris2, + }; + + let ratio = f64::from(to_rate) / f64::from(from_rate); + let mut resampler = SincFixedIn::::new(ratio, 2.0, params, input.len(), 1) + .map_err(|e| LlmError::TranscriptionFailed(format!("resampler init: {e}")))?; + + let output = resampler + .process(&[input], None) + .map_err(|e| LlmError::TranscriptionFailed(format!("resample: {e}")))?; + + Ok(output.into_iter().next().unwrap_or_default()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn device_detection_returns_cpu_by_default() { + let d = detect_device(); + // On CI without GPU, should be CPU + assert!(matches!( + d, + Device::Cpu | Device::Metal(_) | Device::Cuda(_) + )); + } + + #[test] + fn debug_format() { + let d = detect_device(); + let name = device_name(&d); + assert!(!name.is_empty()); + } + + #[test] + fn decode_audio_rejects_invalid_bytes() { + let result = decode_audio(&[0, 1, 2, 3, 4]); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("probe"), "expected probe error, got: {err}"); + } + + #[test] + fn decode_audio_rejects_empty_input() { + let result = decode_audio(&[]); + assert!(result.is_err()); + } + + #[test] + fn resample_zeros_preserves_silence() { + let input = vec![0.0_f32; 1000]; + let output = resample(&input, 44100, 16000).unwrap(); + assert!(!output.is_empty()); + for &s in &output { + assert!(s.abs() < 1e-6, "expected silence, got {s}"); + } + } + + #[test] + fn resample_changes_length() { + let input = vec![0.5_f32; 44100]; + let output = resample(&input, 44100, 16000).unwrap(); + let expected_len = (44100.0 * 16000.0 / 44100.0) as usize; + let tolerance = expected_len / 10; + assert!( + output.len().abs_diff(expected_len) < tolerance, + "expected ~{expected_len} samples, got {}", + output.len() + ); + } +} diff --git a/crates/zeph-llm/src/lib.rs b/crates/zeph-llm/src/lib.rs index a3de000..4bd089a 100644 --- a/crates/zeph-llm/src/lib.rs +++ b/crates/zeph-llm/src/lib.rs @@ -3,6 +3,8 @@ pub mod any; #[cfg(feature = "candle")] pub mod candle_provider; +#[cfg(feature = "candle")] +pub mod candle_whisper; pub mod claude; pub mod compatible; pub mod error; diff --git a/crates/zeph-llm/src/melfilters.bytes b/crates/zeph-llm/src/melfilters.bytes new file mode 100644 index 0000000..0874829 Binary files /dev/null and b/crates/zeph-llm/src/melfilters.bytes differ diff --git a/crates/zeph-llm/src/melfilters128.bytes b/crates/zeph-llm/src/melfilters128.bytes new file mode 100644 index 0000000..f287c5b Binary files /dev/null and b/crates/zeph-llm/src/melfilters128.bytes differ diff --git a/docs/src/feature-flags.md b/docs/src/feature-flags.md index ccc8da6..b61e589 100644 --- a/docs/src/feature-flags.md +++ b/docs/src/feature-flags.md @@ -20,7 +20,7 @@ Zeph uses Cargo feature flags to control optional functionality. As of M26, eigh | Feature | Description | |---------|-------------| | `tui` | ratatui-based TUI dashboard with real-time agent metrics | -| `candle` | Local HuggingFace model inference via [candle](https://github.com/huggingface/candle) (GGUF quantized models) | +| `candle` | Local HuggingFace model inference via [candle](https://github.com/huggingface/candle) (GGUF quantized models) and local Whisper STT ([guide](guide/audio-input.md#local-whisper-candle)) | | `metal` | Metal GPU acceleration for candle on macOS (implies `candle`) | | `cuda` | CUDA GPU acceleration for candle on Linux (implies `candle`) | | `discord` | Discord channel adapter with Gateway v10 WebSocket and slash commands ([guide](guide/channels.md#discord-channel)) | diff --git a/docs/src/guide/audio-input.md b/docs/src/guide/audio-input.md index 4e92f72..b769646 100644 --- a/docs/src/guide/audio-input.md +++ b/docs/src/guide/audio-input.md @@ -33,7 +33,58 @@ The Whisper provider inherits the OpenAI API key from the `[llm.openai]` section | Backend | Provider | Feature | Status | |---------|----------|---------|--------| | OpenAI Whisper API | `whisper` | `stt` | Available | -| Local Whisper (candle) | — | — | Planned | +| Local Whisper (candle) | `candle-whisper` | `candle` | Available | + +## Local Whisper (Candle) + +The `candle-whisper` backend runs Whisper inference locally via [candle](https://github.com/huggingface/candle) — no network calls, fully offline after the initial model download. + +### Requirements + +Enable the `candle` feature flag: + +```bash +cargo build --release --features candle # CPU +cargo build --release --features metal # macOS Metal GPU (implies candle) +cargo build --release --features cuda # Linux NVIDIA GPU (implies candle) +``` + +### Configuration + +```toml +[llm.stt] +provider = "candle-whisper" +model = "openai/whisper-tiny" +``` + +### Model Options + +Models are downloaded from HuggingFace on first use and cached locally. + +| Model | HuggingFace ID | Parameters | Disk | +|-------|---------------|------------|------| +| Tiny | `openai/whisper-tiny` | 39M | ~150 MB | +| Base | `openai/whisper-base` | 74M | ~290 MB | +| Small | `openai/whisper-small` | 244M | ~950 MB | + +Smaller models are faster but less accurate. `whisper-tiny` is a good starting point for low-latency use cases. + +### Device Auto-Detection + +The backend automatically selects the best available compute device: + +1. **Metal** — if `metal` feature is enabled and running on macOS +2. **CUDA** — if `cuda` feature is enabled and an NVIDIA GPU is available +3. **CPU** — fallback + +### Audio Pipeline + +Incoming audio is processed through: symphonia decode, rubato resample to 16 kHz mono, mel spectrogram extraction, then candle Whisper inference. + +### Limitations + +- **5-minute audio duration guard** — recordings longer than 5 minutes are rejected. +- **No streaming** — the entire file is decoded and transcribed in one pass. ## Telegram Voice Messages diff --git a/src/main.rs b/src/main.rs index e36ad3e..a37b0b3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -423,8 +423,39 @@ async fn main() -> anyhow::Result<()> { let agent = agent.with_mcp(mcp_tools, mcp_registry, Some(mcp_manager), &config.mcp); let agent = agent.with_learning(config.skills.learning.clone()); + #[cfg(feature = "candle")] + let agent = if config + .llm + .stt + .as_ref() + .is_some_and(|s| s.provider == "candle-whisper") + { + let model = config + .llm + .stt + .as_ref() + .map_or("openai/whisper-tiny", |s| s.model.as_str()); + match zeph_llm::candle_whisper::CandleWhisperProvider::load(model, None) { + Ok(provider) => { + tracing::info!("STT enabled via candle-whisper (model: {model})"); + agent.with_stt(Box::new(provider)) + } + Err(e) => { + tracing::error!("failed to load candle-whisper: {e}"); + agent + } + } + } else { + agent + }; + #[cfg(feature = "stt")] - let agent = if config.llm.stt.is_some() { + let agent = if config + .llm + .stt + .as_ref() + .is_some_and(|s| s.provider != "candle-whisper") + { if let Some(ref api_key) = config.secrets.openai_api_key { let base_url = config .llm @@ -442,7 +473,7 @@ async fn main() -> anyhow::Result<()> { base_url, model, ); - tracing::info!("STT enabled via Whisper (model: {model})"); + tracing::info!("STT enabled via Whisper API (model: {model})"); agent.with_stt(Box::new(whisper)) } else { tracing::warn!("STT configured but ZEPH_OPENAI_API_KEY not found");