diff --git a/Cargo.lock b/Cargo.lock index 60ac841784..c4a9b9afb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2426,11 +2426,11 @@ dependencies = [ "futures-util", "hound", "kalosm-sound", - "rodio", "serde", + "silero", "thiserror 2.0.12", "tokio", - "vad", + "tracing", ] [[package]] @@ -12935,6 +12935,19 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "silero" +version = "0.1.0" +source = "git+https://github.com/emotechlab/silero-rs?rev=26a6460#26a646003cd8532ae2dde424ccdab1b6cdf5d7b0" +dependencies = [ + "anyhow", + "ndarray", + "ort", + "rubato", + "thiserror 2.0.12", + "tracing", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -16320,17 +16333,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "vad" -version = "0.1.0" -dependencies = [ - "data", - "ndarray", - "ort", - "serde", - "thiserror 2.0.12", -] - [[package]] name = "valuable" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 0357322901..c07f931d1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,6 @@ hypr-slack = { path = "crates/slack", package = "slack" } hypr-stt = { path = "crates/stt", package = "stt", features = ["realtime", "recorded"] } hypr-template = { path = "crates/template", package = "template" } hypr-turso = { path = "crates/turso", package = "turso" } -hypr-vad = { path = "crates/vad", package = "vad" } hypr-whisper = { path = "crates/whisper", package = "whisper" } hypr-whisper-cloud = { path = "crates/whisper-cloud", package = "whisper-cloud" } hypr-whisper-local = { path = "crates/whisper-local", package = "whisper-local" } @@ -193,6 +192,7 @@ hound = "3.5.1" realfft = "3.5.0" ringbuf = "0.4.8" rodio = { version = "0.20.1", features = ["symphonia"] } +silero-rs = { git = "https://github.com/emotechlab/silero-rs", rev = "26a6460", package = "silero" } kalosm-common = { git = "https://github.com/floneum/floneum", rev = "52967ae" } kalosm-llama = { git = "https://github.com/floneum/floneum", rev = "52967ae" } diff --git a/crates/chunker/Cargo.toml b/crates/chunker/Cargo.toml index c0b0c85099..8f1cbb8d17 100644 --- a/crates/chunker/Cargo.toml +++ b/crates/chunker/Cargo.toml @@ -8,15 +8,12 @@ hound = { workspace = true } hypr-data = { workspace = true } [dependencies] -hypr-vad = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } + kalosm-sound = { workspace = true, default-features = false } -rodio = { workspace = true } +silero-rs = { workspace = true } futures-util = { workspace = true } -serde = { workspace = true } -thiserror = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } - -[features] -default = [] -load-dynamic = ["hypr-vad/load-dynamic"] +tracing = { workspace = true } diff --git a/crates/chunker/src/error.rs b/crates/chunker/src/error.rs index 8b2ae65742..806640a9b0 100644 --- a/crates/chunker/src/error.rs +++ b/crates/chunker/src/error.rs @@ -1,16 +1,7 @@ -use serde::{ser::Serializer, Serialize}; - #[derive(Debug, thiserror::Error)] pub enum Error { - #[error(transparent)] - VadError(#[from] hypr_vad::Error), -} - -impl Serialize for Error { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - serializer.serialize_str(self.to_string().as_ref()) - } + #[error("Failed to create VAD session")] + VadSessionCreationFailed, + #[error("Failed to process audio")] + VadProcessingFailed(String), } diff --git a/crates/chunker/src/lib.rs b/crates/chunker/src/lib.rs index 98fb4e9634..45b6e86e00 100644 --- a/crates/chunker/src/lib.rs +++ b/crates/chunker/src/lib.rs @@ -1,61 +1,143 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures_util::Stream; +use kalosm_sound::AsyncSource; + +use silero_rs::{VadConfig, VadSession, VadTransition}; + mod error; -mod predictor; -mod stream; +use error::*; -pub use error::*; -pub use predictor::*; -pub use stream::*; +pub struct ChunkStream { + source: S, + chunk_samples: usize, + buffer: Vec, +} -use kalosm_sound::AsyncSource; -use std::time::Duration; - -pub trait ChunkerExt: AsyncSource + Sized { - fn chunks( - self, - predictor: P, - chunk_duration: Duration, - ) -> ChunkStream +impl ChunkStream { + fn new(source: S, chunk_duration: Duration) -> Self { + let sample_rate = source.sample_rate(); + let chunk_samples = (chunk_duration.as_secs_f64() * sample_rate as f64) as usize; + + Self { + source, + chunk_samples, + buffer: Vec::with_capacity(chunk_samples), + } + } +} + +impl Stream for ChunkStream { + type Item = Vec; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let stream = this.source.as_stream(); + let mut stream = std::pin::pin!(stream); + + while this.buffer.len() < this.chunk_samples { + match stream.as_mut().poll_next(cx) { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(Some(sample)) => { + this.buffer.push(sample); + } + Poll::Ready(None) => { + if this.buffer.is_empty() { + return Poll::Ready(None); + } else { + let chunk = std::mem::take(&mut this.buffer); + return Poll::Ready(Some(chunk)); + } + } + } + } + + let mut chunk = Vec::with_capacity(this.chunk_samples); + chunk.extend(this.buffer.drain(..this.chunk_samples)); + Poll::Ready(Some(chunk)) + } +} + +pub trait VadExt: AsyncSource + Sized { + fn vad_chunks(self) -> VadChunkStream where Self: Unpin, { - ChunkStream::new(self, predictor, chunk_duration) + let config = VadConfig { + post_speech_pad: Duration::from_millis(50), + ..Default::default() + }; + + VadChunkStream::new(self, config).unwrap() } } -impl ChunkerExt for T {} +impl VadExt for T {} -#[cfg(test)] -mod tests { - use super::*; - use futures_util::StreamExt; +pub struct VadChunkStream { + chunk_stream: ChunkStream, + vad_session: VadSession, + pending_chunks: Vec, +} - #[tokio::test] - async fn test_chunker() { - let audio_source = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap(); +impl VadChunkStream { + fn new(source: S, mut config: VadConfig) -> Result { + config.sample_rate = source.sample_rate() as usize; - let spec = hound::WavSpec { - channels: 1, - sample_rate: 16000, - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, - }; + // https://github.com/emotechlab/silero-rs/blob/26a6460/src/lib.rs#L775 + let chunk_duration = Duration::from_millis(30); - let mut stream = audio_source.chunks(RMS::new(), Duration::from_secs(15)); - let mut i = 0; + Ok(Self { + chunk_stream: ChunkStream::new(source, chunk_duration), + vad_session: VadSession::new(config).map_err(|_| Error::VadSessionCreationFailed)?, + pending_chunks: Vec::new(), + }) + } +} + +#[derive(Debug, Clone)] +pub struct AudioChunk { + pub samples: Vec, +} + +impl Stream for VadChunkStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(chunk) = this.pending_chunks.pop() { + return Poll::Ready(Some(Ok(chunk))); + } - let _ = std::fs::remove_dir_all("tmp/english_1"); - let _ = std::fs::create_dir_all("tmp/english_1"); + loop { + match Pin::new(&mut this.chunk_stream).poll_next(cx) { + Poll::Ready(Some(samples)) => match this.vad_session.process(&samples) { + Ok(transitions) => { + for transition in transitions { + if let VadTransition::SpeechEnd { samples, .. } = transition { + this.pending_chunks.push(AudioChunk { samples }); + } + } - while let Some(chunk) = stream.next().await { - let file = std::fs::File::create(format!("tmp/english_1/chunk_{}.wav", i)).unwrap(); - let mut writer = hound::WavWriter::new(file, spec).unwrap(); - for sample in chunk { - writer.write_sample(sample).unwrap(); + if let Some(chunk) = this.pending_chunks.pop() { + return Poll::Ready(Some(Ok(chunk))); + } + } + Err(e) => { + let error = Error::VadProcessingFailed(e.to_string()); + return Poll::Ready(Some(Err(error))); + } + }, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, } - i += 1; } } } diff --git a/crates/chunker/src/predictor.rs b/crates/chunker/src/predictor.rs deleted file mode 100644 index ee73507a49..0000000000 --- a/crates/chunker/src/predictor.rs +++ /dev/null @@ -1,45 +0,0 @@ -pub trait Predictor: Send + Sync { - fn predict(&self, samples: &[f32]) -> Result; -} - -#[derive(Debug)] -pub struct RMS {} - -impl RMS { - pub fn new() -> Self { - Self {} - } -} - -impl Predictor for RMS { - fn predict(&self, samples: &[f32]) -> Result { - if samples.is_empty() { - return Ok(false); - } - - let sum_squares: f32 = samples.iter().map(|&sample| sample * sample).sum(); - let mean_square = sum_squares / samples.len() as f32; - let rms = mean_square.sqrt(); - Ok(rms > 0.009) - } -} - -#[derive(Debug)] -pub struct Silero { - #[allow(dead_code)] - inner: hypr_vad::Vad, -} - -impl Silero { - pub fn new() -> Result { - Ok(Self { - inner: hypr_vad::Vad::new()?, - }) - } -} - -impl Predictor for Silero { - fn predict(&self, _samples: &[f32]) -> Result { - Ok(true) - } -} diff --git a/crates/chunker/src/stream.rs b/crates/chunker/src/stream.rs deleted file mode 100644 index 7e0f9d5d6a..0000000000 --- a/crates/chunker/src/stream.rs +++ /dev/null @@ -1,104 +0,0 @@ -use futures_util::Stream; -use std::{ - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; - -use kalosm_sound::AsyncSource; -use rodio::buffer::SamplesBuffer; - -use crate::Predictor; - -pub struct ChunkStream { - source: S, - predictor: P, - buffer: Vec, - max_duration: Duration, -} - -impl ChunkStream { - pub fn new(source: S, predictor: P, max_duration: Duration) -> Self { - Self { - source, - predictor, - buffer: Vec::new(), - max_duration, - } - } - - fn max_samples(&self) -> usize { - (self.source.sample_rate() as f64 * self.max_duration.as_secs_f64()) as usize - } - - fn samples_for_duration(&self, duration: Duration) -> usize { - (self.source.sample_rate() as f64 * duration.as_secs_f64()) as usize - } - - fn trim_silence(predictor: &P, data: &mut Vec) { - const WINDOW_SIZE: usize = 100; - - let mut trim_index = 0; - for start_idx in (0..data.len()).step_by(WINDOW_SIZE) { - let end_idx = (start_idx + WINDOW_SIZE).min(data.len()); - let window = &data[start_idx..end_idx]; - - if let Ok(false) = predictor.predict(window) { - trim_index = start_idx; - break; - } - } - - data.drain(0..trim_index); - } -} - -impl Stream for ChunkStream { - type Item = SamplesBuffer; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let max_samples = this.max_samples(); - let sample_rate = this.source.sample_rate(); - - let min_buffer_samples = this.samples_for_duration(Duration::from_secs(6)); - let silence_window_samples = this.samples_for_duration(Duration::from_millis(500)); - - let stream = this.source.as_stream(); - let mut stream = std::pin::pin!(stream); - - while this.buffer.len() < max_samples { - match stream.as_mut().poll_next(cx) { - Poll::Ready(Some(sample)) => { - this.buffer.push(sample); - - if this.buffer.len() >= min_buffer_samples { - let buffer_len = this.buffer.len(); - let silence_start = buffer_len.saturating_sub(silence_window_samples); - let last_samples = &this.buffer[silence_start..buffer_len]; - - if let Ok(false) = this.predictor.predict(last_samples) { - let mut data = std::mem::take(&mut this.buffer); - Self::trim_silence(&this.predictor, &mut data); - - return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, data))); - } - } - } - Poll::Ready(None) if !this.buffer.is_empty() => { - let mut data = std::mem::take(&mut this.buffer); - Self::trim_silence(&this.predictor, &mut data); - - return Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, data))); - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - - let mut chunk: Vec<_> = this.buffer.drain(0..max_samples).collect(); - Self::trim_silence(&this.predictor, &mut chunk); - - Poll::Ready(Some(SamplesBuffer::new(1, sample_rate, chunk))) - } -} diff --git a/crates/vad/Cargo.toml b/crates/vad/Cargo.toml deleted file mode 100644 index a114c4a8ed..0000000000 --- a/crates/vad/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "vad" -version = "0.1.0" -edition = "2021" - -[features] -default = [] -load-dynamic = ["ort/load-dynamic"] - -[dependencies] -serde = { workspace = true } -thiserror = { workspace = true } - -ndarray = "0.16" -ort = { version = "=2.0.0-rc.10", features = ["ndarray"] } - -[dev-dependencies] -hypr-data = { workspace = true } diff --git a/crates/vad/assets/model.onnx b/crates/vad/assets/model.onnx deleted file mode 100644 index e6db48d6e2..0000000000 Binary files a/crates/vad/assets/model.onnx and /dev/null differ diff --git a/crates/vad/src/error.rs b/crates/vad/src/error.rs deleted file mode 100644 index b02763d7ef..0000000000 --- a/crates/vad/src/error.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::{ser::Serializer, Serialize}; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error(transparent)] - OrtError(#[from] ort::Error), - #[error(transparent)] - ShapeError(#[from] ndarray::ShapeError), - #[error("Invalid or missing output from model")] - InvalidOutput, -} - -impl Serialize for Error { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - serializer.serialize_str(self.to_string().as_ref()) - } -} diff --git a/crates/vad/src/lib.rs b/crates/vad/src/lib.rs deleted file mode 100644 index 82b422b1f7..0000000000 --- a/crates/vad/src/lib.rs +++ /dev/null @@ -1,153 +0,0 @@ -mod error; -pub use error::*; - -use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr}; -use ort::{ - session::{builder::GraphOptimizationLevel, Session}, - value::TensorRef, -}; - -const MODEL_BYTES: &[u8] = - include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/model.onnx")); - -const SAMPLE_RATE: i64 = 16000; - -const fn ms_to_samples(ms: usize) -> usize { - (ms * SAMPLE_RATE as usize) / 1000 -} - -#[derive(Debug)] -pub struct Vad { - session: Session, - h_tensor: ArrayBase, Ix3>, - c_tensor: ArrayBase, Ix3>, - sample_rate_tensor: ArrayBase, Ix1>, -} - -impl Vad { - pub fn new() -> Result { - let session = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(4)? - .commit_from_memory(MODEL_BYTES)?; - - let h_tensor = Array3::::zeros((2, 1, 64)); - let c_tensor = Array3::::zeros((2, 1, 64)); - let sample_rate_tensor = Array1::from_vec(vec![SAMPLE_RATE]); - - Ok(Self { - session, - h_tensor, - c_tensor, - sample_rate_tensor, - }) - } - - /// Process a chunk of audio samples through the model and return the speech probability - fn forward(&mut self, audio_chunk: &[f32]) -> Result { - let samples = audio_chunk.len(); - let audio_tensor = Array2::from_shape_vec((1, samples), audio_chunk.to_vec())?; - - let mut result = self.session.run(ort::inputs![ - TensorRef::from_array_view(audio_tensor.view())?, - TensorRef::from_array_view(self.sample_rate_tensor.view())?, - TensorRef::from_array_view(self.h_tensor.view())?, - TensorRef::from_array_view(self.c_tensor.view())?, - ])?; - - // Update internal state tensors - self.h_tensor = result - .get("hn") - .ok_or(Error::InvalidOutput)? - .try_extract_array::()? - .to_owned() - .into_shape_with_order((2, 1, 64))?; - - self.c_tensor = result - .get("cn") - .ok_or(Error::InvalidOutput)? - .try_extract_array::()? - .to_owned() - .into_shape_with_order((2, 1, 64))?; - - let prob_tensor = result.remove("output").ok_or(Error::InvalidOutput)?; - let prob = *prob_tensor - .try_extract_array::()? - .first() - .ok_or(Error::InvalidOutput)?; - - Ok(prob) - } - - /// For longer audio, this will process in 30ms chunks and return the maximum probability - pub fn run(&mut self, audio_samples: &[f32]) -> Result { - if audio_samples.len() < ms_to_samples(30) { - return self.forward(audio_samples); - } - - let chunk_size = ms_to_samples(30); - let num_chunks = audio_samples.len() / chunk_size; - - let mut max_prob = 0.0f32; - - for i in 0..num_chunks { - let start = i * chunk_size; - let end = (start + chunk_size).min(audio_samples.len()); - let prob = self.forward(&audio_samples[start..end])?; - max_prob = max_prob.max(prob); - } - - let remaining_start = num_chunks * chunk_size; - if remaining_start < audio_samples.len() - && audio_samples.len() - remaining_start >= (chunk_size / 2) - { - let prob = self.forward(&audio_samples[remaining_start..])?; - max_prob = max_prob.max(prob); - } - - Ok(max_prob) - } - - pub fn reset(&mut self) { - self.h_tensor = Array3::::zeros((2, 1, 64)); - self.c_tensor = Array3::::zeros((2, 1, 64)); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_vad_silence() { - let mut vad = Vad::new().unwrap(); - let audio_samples = vec![0.0; 16000]; - let prob = vad.run(&audio_samples).unwrap(); - assert!(prob < 0.1); - } - - #[test] - fn test_vad_english_1() { - let mut vad = Vad::new().unwrap(); - let audio_samples = to_f32(hypr_data::english_1::AUDIO); - let prob = vad.run(&audio_samples).unwrap(); - assert!(prob > 0.8); - } - - #[test] - fn test_vad_english_2() { - let mut vad = Vad::new().unwrap(); - let audio_samples = to_f32(hypr_data::english_2::AUDIO); - let prob = vad.run(&audio_samples).unwrap(); - assert!(prob > 0.8); - } - - fn to_f32(bytes: &[u8]) -> Vec { - let mut samples = Vec::with_capacity(bytes.len() / 2); - for chunk in bytes.chunks_exact(2) { - let sample = i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0; - samples.push(sample); - } - samples - } -} diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index ac1bae27b7..9bf34386ec 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -21,7 +21,7 @@ openblas = ["hypr-whisper-local/openblas"] metal = ["hypr-whisper-local/metal"] vulkan = ["hypr-whisper-local/vulkan"] openmp = ["hypr-whisper-local/openmp"] -load-dynamic = ["hypr-pyannote-local/load-dynamic", "hypr-chunker/load-dynamic"] +load-dynamic = ["hypr-pyannote-local/load-dynamic"] [build-dependencies] tauri-plugin = { workspace = true, features = ["build"] } diff --git a/plugins/local-stt/src/server.rs b/plugins/local-stt/src/server.rs index caa191794e..023e60c4bc 100644 --- a/plugins/local-stt/src/server.rs +++ b/plugins/local-stt/src/server.rs @@ -15,10 +15,9 @@ use axum::{ }; use futures_util::{SinkExt, StreamExt}; -use rodio::Source; use tower_http::cors::{self, CorsLayer}; -use hypr_chunker::ChunkerExt; +use hypr_chunker::VadExt; use hypr_listener_interface::{ListenOutputChunk, ListenParams, Word}; use crate::manager::{ConnectionGuard, ConnectionManager}; @@ -149,20 +148,12 @@ async fn websocket_single_channel( model: hypr_whisper_local::Whisper, guard: ConnectionGuard, ) { - let stream = { - let audio_source = hypr_ws_utils::WebSocketAudioSource::new(ws_receiver, 16 * 1000); - let chunked = - audio_source.chunks(hypr_chunker::RMS::new(), std::time::Duration::from_secs(13)); - - let chunked = hypr_whisper_local::AudioChunkStream(chunked.map(|chunk| { - hypr_whisper_local::SimpleAudioChunk { - samples: chunk.convert_samples().collect(), - meta: Some(serde_json::json!({ "source": "mixed" })), - } - })); - hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(chunked, model) - }; + let audio_source = hypr_ws_utils::WebSocketAudioSource::new(ws_receiver, 16 * 1000); + let vad_chunks = audio_source.vad_chunks(); + let chunked = hypr_whisper_local::AudioChunkStream(process_vad_stream(vad_chunks, "mixed")); + + let stream = hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(chunked, model); process_transcription_stream(ws_sender, stream, guard).await; } @@ -175,24 +166,15 @@ async fn websocket_dual_channel( let (mic_source, speaker_source) = hypr_ws_utils::split_dual_audio_sources(ws_receiver, 16 * 1000); - let mic_chunked = - mic_source.chunks(hypr_chunker::RMS::new(), std::time::Duration::from_secs(13)); - let speaker_chunked = - speaker_source.chunks(hypr_chunker::RMS::new(), std::time::Duration::from_secs(13)); - - let mic_chunked = hypr_whisper_local::AudioChunkStream(mic_chunked.map(|chunk| { - hypr_whisper_local::SimpleAudioChunk { - samples: chunk.convert_samples().collect(), - meta: Some(serde_json::json!({ "source": "mic" })), - } - })); + let mic_chunked = { + let mic_vad_chunks = mic_source.vad_chunks(); + hypr_whisper_local::AudioChunkStream(process_vad_stream(mic_vad_chunks, "mic")) + }; - let speaker_chunked = hypr_whisper_local::AudioChunkStream(speaker_chunked.map(|chunk| { - hypr_whisper_local::SimpleAudioChunk { - samples: chunk.convert_samples().collect(), - meta: Some(serde_json::json!({ "source": "speaker" })), - } - })); + let speaker_chunked = { + let speaker_vad_chunks = speaker_source.vad_chunks(); + hypr_whisper_local::AudioChunkStream(process_vad_stream(speaker_vad_chunks, "speaker")) + }; let merged_stream = hypr_whisper_local::AudioChunkStream(futures_util::stream::select( mic_chunked.0, @@ -225,7 +207,7 @@ async fn process_transcription_stream( let duration = chunk.duration() as u64; let confidence = chunk.confidence(); - if confidence < 0.2 { + if confidence < 0.1 { tracing::warn!(confidence, "skipping_transcript: {}", text); continue; } @@ -267,3 +249,34 @@ async fn process_transcription_stream( let _ = ws_sender.close().await; } + +fn process_vad_stream( + stream: S, + source_name: &str, +) -> impl futures_util::Stream +where + S: futures_util::Stream>, + E: std::fmt::Display, +{ + let source_name = source_name.to_string(); + + stream + .take_while(move |chunk_result| { + futures_util::future::ready(match chunk_result { + Ok(_) => true, + Err(e) => { + tracing::error!("vad_error_disconnecting: {}", e); + false // This will end the stream + } + }) + }) + .filter_map(move |chunk_result| { + futures_util::future::ready(match chunk_result { + Err(_) => None, // This shouldn't happen due to take_while above + Ok(chunk) => Some(hypr_whisper_local::SimpleAudioChunk { + samples: chunk.samples, + meta: Some(serde_json::json!({ "source": source_name })), + }), + }) + }) +}