diff --git a/Cargo.lock b/Cargo.lock index 4230dbfb77..de3672d7ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1545,7 +1545,6 @@ dependencies = [ "matchit 0.8.4", "memchr", "mime", - "multer", "percent-encoding", "pin-project-lite", "rustversion", @@ -4349,12 +4348,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "extended" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" - [[package]] name = "eyre" version = "0.6.12" @@ -8806,23 +8799,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "multer" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" -dependencies = [ - "bytes", - "encoding_rs", - "futures-util", - "http 1.3.1", - "httparse", - "memchr", - "mime", - "spin 0.9.8", - "version_check", -] - [[package]] name = "multimap" version = "0.10.1" @@ -13574,30 +13550,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" dependencies = [ "lazy_static", - "symphonia-bundle-flac", "symphonia-bundle-mp3", - "symphonia-codec-aac", - "symphonia-codec-adpcm", - "symphonia-codec-pcm", - "symphonia-codec-vorbis", "symphonia-core", - "symphonia-format-isomp4", - "symphonia-format-riff", "symphonia-metadata", ] -[[package]] -name = "symphonia-bundle-flac" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97" -dependencies = [ - "log", - "symphonia-core", - "symphonia-metadata", - "symphonia-utils-xiph", -] - [[package]] name = "symphonia-bundle-mp3" version = "0.5.4" @@ -13610,48 +13567,6 @@ dependencies = [ "symphonia-metadata", ] -[[package]] -name = "symphonia-codec-aac" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdbf25b545ad0d3ee3e891ea643ad115aff4ca92f6aec472086b957a58522f70" -dependencies = [ - "lazy_static", - "log", - "symphonia-core", -] - -[[package]] -name = "symphonia-codec-adpcm" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c94e1feac3327cd616e973d5be69ad36b3945f16b06f19c6773fc3ac0b426a0f" -dependencies = [ - "log", - "symphonia-core", -] - -[[package]] -name = "symphonia-codec-pcm" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" -dependencies = [ - "log", - "symphonia-core", -] - -[[package]] -name = "symphonia-codec-vorbis" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30" -dependencies = [ - "log", - "symphonia-core", - "symphonia-utils-xiph", -] - [[package]] name = "symphonia-core" version = "0.5.4" @@ -13665,31 +13580,6 @@ dependencies = [ "log", ] -[[package]] -name = "symphonia-format-isomp4" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844" -dependencies = [ - "encoding_rs", - "log", - "symphonia-core", - "symphonia-metadata", - "symphonia-utils-xiph", -] - -[[package]] -name = "symphonia-format-riff" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" -dependencies = [ - "extended", - "log", - "symphonia-core", - "symphonia-metadata", -] - [[package]] name = "symphonia-metadata" version = "0.5.4" @@ -13702,16 +13592,6 @@ dependencies = [ "symphonia-core", ] -[[package]] -name = "symphonia-utils-xiph" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe" -dependencies = [ - "symphonia-core", - "symphonia-metadata", -] - [[package]] name = "syn" version = "1.0.109" @@ -14461,7 +14341,6 @@ dependencies = [ "audio-utils", "axum 0.8.4", "axum-extra", - "chunker", "data", "dirs 6.0.0", "file", @@ -14469,7 +14348,6 @@ dependencies = [ "inventory", "language", "listener-interface", - "pyannote-local", "reqwest 0.12.22", "rodio", "serde", @@ -14492,9 +14370,8 @@ dependencies = [ "tokio-util", "tower-http 0.6.6", "tracing", - "whisper", + "transcribe-whisper-local", "whisper-local", - "ws-utils", ] [[package]] @@ -15947,6 +15824,8 @@ dependencies = [ "aws-config", "aws-sdk-transcribe", "aws-sdk-transcribestreaming", + "aws-smithy-runtime-api", + "aws-smithy-types", "axum 0.8.4", "bytes", "data", @@ -16001,6 +15880,38 @@ dependencies = [ "tracing", ] +[[package]] +name = "transcribe-interface" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "thiserror 2.0.12", +] + +[[package]] +name = "transcribe-whisper-local" +version = "0.1.0" +dependencies = [ + "audio-utils", + "axum 0.8.4", + "chunker", + "futures-util", + "listener-interface", + "pyannote-local", + "rodio", + "serde_json", + "serde_qs 1.0.0-rc.3", + "thiserror 2.0.12", + "tokio", + "tokio-util", + "tower 0.5.2", + "tracing", + "whisper", + "whisper-local", + "ws-utils", +] + [[package]] name = "transpose" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index 97679b4741..49c5a7d2c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ hypr-template = { path = "crates/template", package = "template" } hypr-transcribe-aws = { path = "crates/transcribe-aws", package = "transcribe-aws" } hypr-transcribe-azure = { path = "crates/transcribe-azure", package = "transcribe-azure" } hypr-transcribe-gcp = { path = "crates/transcribe-gcp", package = "transcribe-gcp" } +hypr-transcribe-whisper-local = { path = "crates/transcribe-whisper-local", package = "transcribe-whisper-local" } hypr-turso = { path = "crates/turso", package = "turso" } hypr-whisper = { path = "crates/whisper", package = "whisper" } hypr-whisper-cloud = { path = "crates/whisper-cloud", package = "whisper-cloud" } diff --git a/crates/transcribe-aws/Cargo.toml b/crates/transcribe-aws/Cargo.toml index 32dd92b5b2..dc0fe98f8a 100644 --- a/crates/transcribe-aws/Cargo.toml +++ b/crates/transcribe-aws/Cargo.toml @@ -21,6 +21,8 @@ tracing = { workspace = true } aws-config = "1.8.3" aws-sdk-transcribe = "1.83.0" aws-sdk-transcribestreaming = "1.80.0" +aws-smithy-runtime-api = "1.8.5" +aws-smithy-types = "1.3.2" [dev-dependencies] hypr-data = { workspace = true } diff --git a/crates/transcribe-aws/src/error.rs b/crates/transcribe-aws/src/error.rs index aae0243e40..723ec8d4b2 100644 --- a/crates/transcribe-aws/src/error.rs +++ b/crates/transcribe-aws/src/error.rs @@ -1,7 +1,21 @@ #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Invalid input: {0}")] - InvalidInput(String), - #[error("Service error: {0}")] - ServiceError(String), + #[error(transparent)] + GenericError(#[from] aws_sdk_transcribestreaming::Error), + #[error(transparent)] + TranscriptResultStreamError( + #[from] + aws_smithy_runtime_api::client::result::SdkError< + aws_sdk_transcribestreaming::types::error::TranscriptResultStreamError, + aws_smithy_types::event_stream::RawMessage, + >, + ), + #[error(transparent)] + StartStreamTranscriptionError( + #[from] + aws_smithy_runtime_api::client::result::SdkError< + aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionError, + aws_smithy_runtime_api::client::orchestrator::HttpResponse, + >, + ), } diff --git a/crates/transcribe-aws/src/lib.rs b/crates/transcribe-aws/src/lib.rs index dc336ebaee..9fdbbd8d56 100644 --- a/crates/transcribe-aws/src/lib.rs +++ b/crates/transcribe-aws/src/lib.rs @@ -1,3 +1,5 @@ +// AWS draft + use bytes::Bytes; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -28,7 +30,7 @@ use aws_sdk_transcribestreaming::primitives::Blob; use aws_sdk_transcribestreaming::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; -use aws_sdk_transcribestreaming::{config::Region, Client, Error}; +use aws_sdk_transcribestreaming::{config::Region, Client}; mod error; pub use error::*; @@ -72,7 +74,7 @@ pub struct TranscribeService { } impl TranscribeService { - pub async fn new(config: TranscribeConfig) -> Result { + pub async fn new(config: TranscribeConfig) -> Result { let region_provider = RegionProviderChain::first_try(config.region.clone().map(Region::new)) .or_default_provider() @@ -143,7 +145,7 @@ impl TranscribeService { &self, mut audio_rx: mpsc::Receiver, result_tx: mpsc::Sender, - ) -> Result<(), Error> { + ) -> Result<(), crate::Error> { // Create audio stream for AWS Transcribe let input_stream = stream! { while let Some(chunk) = audio_rx.recv().await { @@ -166,7 +168,6 @@ impl TranscribeService { .send() .await?; - // Process transcription results while let Some(event) = output.transcript_result_stream.recv().await? { match event { TranscriptResultStream::TranscriptEvent(transcript_event) => { diff --git a/crates/transcribe-interface/Cargo.toml b/crates/transcribe-interface/Cargo.toml new file mode 100644 index 0000000000..2755db1f65 --- /dev/null +++ b/crates/transcribe-interface/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "transcribe-interface" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } diff --git a/crates/transcribe-interface/src/lib.rs b/crates/transcribe-interface/src/lib.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/crates/transcribe-interface/src/lib.rs @@ -0,0 +1 @@ + diff --git a/crates/transcribe-whisper-local/Cargo.toml b/crates/transcribe-whisper-local/Cargo.toml new file mode 100644 index 0000000000..bf2a64fe24 --- /dev/null +++ b/crates/transcribe-whisper-local/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "transcribe-whisper-local" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +coreml = ["hypr-whisper-local/coreml"] +directml = ["hypr-pyannote-local/directml"] +cuda = ["hypr-whisper-local/cuda"] +hipblas = ["hypr-whisper-local/hipblas"] +openblas = ["hypr-whisper-local/openblas"] +metal = ["hypr-whisper-local/metal"] +vulkan = ["hypr-whisper-local/vulkan"] +openmp = ["hypr-whisper-local/openmp"] +load-dynamic = ["hypr-pyannote-local/load-dynamic"] + +[dependencies] +hypr-audio-utils = { workspace = true } +hypr-chunker = { workspace = true } +hypr-listener-interface = { workspace = true } +hypr-pyannote-local = { workspace = true } +hypr-whisper = { workspace = true } +hypr-whisper-local = { workspace = true } +hypr-ws-utils = { workspace = true } + +serde_json = { workspace = true } +serde_qs = { workspace = true } +thiserror = { workspace = true } + +axum = { workspace = true, features = ["ws"] } +futures-util = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +tower = { workspace = true } +tracing = { workspace = true } + +rodio = { workspace = true } diff --git a/crates/transcribe-whisper-local/src/error.rs b/crates/transcribe-whisper-local/src/error.rs new file mode 100644 index 0000000000..0afaa5e7f4 --- /dev/null +++ b/crates/transcribe-whisper-local/src/error.rs @@ -0,0 +1,2 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error {} diff --git a/crates/transcribe-whisper-local/src/lib.rs b/crates/transcribe-whisper-local/src/lib.rs new file mode 100644 index 0000000000..8f0dc5169d --- /dev/null +++ b/crates/transcribe-whisper-local/src/lib.rs @@ -0,0 +1,6 @@ +mod error; +mod manager; +mod service; + +pub use error::*; +pub use service::*; diff --git a/crates/transcribe-whisper-local/src/manager.rs b/crates/transcribe-whisper-local/src/manager.rs new file mode 100644 index 0000000000..1b66b72039 --- /dev/null +++ b/crates/transcribe-whisper-local/src/manager.rs @@ -0,0 +1,40 @@ +use std::sync::{Arc, Mutex}; +use tokio_util::sync::CancellationToken; + +#[derive(Clone)] +pub struct ConnectionManager { + inner: Arc>>, +} + +impl Default for ConnectionManager { + fn default() -> Self { + Self { + inner: Arc::new(Mutex::new(None)), + } + } +} + +impl ConnectionManager { + pub fn acquire_connection(&self) -> ConnectionGuard { + let mut slot = self.inner.lock().unwrap(); + + if let Some(old) = slot.take() { + old.cancel(); + } + + let token = CancellationToken::new(); + *slot = Some(token.clone()); + + ConnectionGuard { token } + } +} + +pub struct ConnectionGuard { + token: CancellationToken, +} + +impl ConnectionGuard { + pub async fn cancelled(&self) { + self.token.cancelled().await + } +} diff --git a/crates/transcribe-whisper-local/src/service/mod.rs b/crates/transcribe-whisper-local/src/service/mod.rs new file mode 100644 index 0000000000..33c40125b7 --- /dev/null +++ b/crates/transcribe-whisper-local/src/service/mod.rs @@ -0,0 +1,5 @@ +mod streaming; +pub use streaming::*; + +mod recorded; +pub use recorded::*; diff --git a/crates/transcribe-whisper-local/src/service/recorded.rs b/crates/transcribe-whisper-local/src/service/recorded.rs new file mode 100644 index 0000000000..cbd313b68c --- /dev/null +++ b/crates/transcribe-whisper-local/src/service/recorded.rs @@ -0,0 +1,61 @@ +use hypr_listener_interface::Word; + +pub fn process_recorded( + model_path: impl AsRef, + audio_path: impl AsRef, +) -> Result, crate::Error> { + use rodio::Source; + + let decoder = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(audio_path.as_ref()).unwrap(), + )) + .unwrap(); + + let original_sample_rate = decoder.sample_rate(); + + let resampled_samples = if original_sample_rate != 16000 { + hypr_audio_utils::resample_audio(decoder, 16000).unwrap() + } else { + decoder.convert_samples().collect() + }; + + let samples_i16 = hypr_audio_utils::f32_to_i16_samples(&resampled_samples); + + let mut model = hypr_whisper_local::Whisper::builder() + .model_path(model_path.as_ref().to_str().unwrap()) + .languages(vec![]) + .static_prompt("") + .dynamic_prompt("") + .build(); + + let mut segmenter = hypr_pyannote_local::segmentation::Segmenter::new(16000).unwrap(); + let segments = segmenter.process(&samples_i16, 16000).unwrap(); + + let mut words = Vec::new(); + + for segment in segments { + let audio_f32 = hypr_audio_utils::i16_to_f32_samples(&segment.samples); + + let whisper_segments = model.transcribe(&audio_f32).unwrap(); + + for whisper_segment in whisper_segments { + let start_sec: f64 = segment.start + (whisper_segment.start() as f64); + let end_sec: f64 = segment.start + (whisper_segment.end() as f64); + let start_ms = (start_sec * 1000.0) as u64; + let end_ms = (end_sec * 1000.0) as u64; + + let word = Word { + text: whisper_segment.text().to_string(), + speaker: None, + confidence: Some(whisper_segment.confidence()), + start_ms: Some(start_ms), + end_ms: Some(end_ms), + }; + + // TODO + words.push(word.clone()); + } + } + + Ok(words) +} diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs new file mode 100644 index 0000000000..c5e84ab5c9 --- /dev/null +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -0,0 +1,283 @@ +use std::{ + future::Future, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + FromRequestParts, + }, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, +}; +use futures_util::{SinkExt, StreamExt}; +use tower::Service; + +use hypr_chunker::VadExt; +use hypr_listener_interface::{ListenOutputChunk, ListenParams, Word}; + +use crate::manager::{ConnectionGuard, ConnectionManager}; + +#[derive(Clone)] +pub struct WhisperStreamingService { + model_path: PathBuf, + connection_manager: ConnectionManager, +} + +impl WhisperStreamingService { + pub fn builder() -> WhisperStreamingServiceBuilder { + WhisperStreamingServiceBuilder::default() + } +} + +#[derive(Default)] +pub struct WhisperStreamingServiceBuilder { + model_path: Option, + connection_manager: Option, +} + +impl WhisperStreamingServiceBuilder { + pub fn model_path(mut self, model_path: PathBuf) -> Self { + self.model_path = Some(model_path); + self + } + + pub fn build(self) -> WhisperStreamingService { + WhisperStreamingService { + model_path: self.model_path.unwrap(), + connection_manager: self + .connection_manager + .unwrap_or_else(ConnectionManager::default), + } + } +} + +impl Service> for WhisperStreamingService +where + B: Send + 'static, +{ + type Response = Response; + type Error = std::convert::Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let model_path = self.model_path.clone(); + let connection_manager = self.connection_manager.clone(); + + Box::pin(async move { + let uri = req.uri(); + let query_string = uri.query().unwrap_or(""); + let params: ListenParams = match serde_qs::from_str(query_string) { + Ok(p) => p, + Err(_) => { + return Ok(StatusCode::BAD_REQUEST.into_response()); + } + }; + + let (mut parts, _body) = req.into_parts(); + let ws_upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await { + Ok(ws) => ws, + Err(_) => { + return Ok(StatusCode::BAD_REQUEST.into_response()); + } + }; + + let guard = connection_manager.acquire_connection(); + + let response = ws_upgrade.on_upgrade(move |socket| async move { + handle_websocket_connection(socket, params, model_path, guard).await + }); + + Ok(response.into_response()) + }) + } +} + +async fn handle_websocket_connection( + socket: WebSocket, + params: ListenParams, + model_path: PathBuf, + guard: ConnectionGuard, +) { + let languages: Vec = params + .languages + .into_iter() + .filter_map(|lang| lang.try_into().ok()) + .collect(); + + let model = hypr_whisper_local::Whisper::builder() + .model_path(model_path.to_str().unwrap()) + .languages(languages) + .static_prompt(¶ms.static_prompt) + .dynamic_prompt(¶ms.dynamic_prompt) + .build(); + + let (ws_sender, ws_receiver) = socket.split(); + + match params.audio_mode { + hypr_listener_interface::AudioMode::Single => { + handle_single_channel( + ws_sender, + ws_receiver, + model, + guard, + Duration::from_millis(params.redemption_time_ms), + ) + .await; + } + hypr_listener_interface::AudioMode::Dual => { + handle_dual_channel( + ws_sender, + ws_receiver, + model, + guard, + Duration::from_millis(params.redemption_time_ms), + ) + .await; + } + } +} + +async fn handle_single_channel( + ws_sender: futures_util::stream::SplitSink, + ws_receiver: futures_util::stream::SplitStream, + model: hypr_whisper_local::Whisper, + guard: ConnectionGuard, + redemption_time: Duration, +) { + let audio_source = hypr_ws_utils::WebSocketAudioSource::new(ws_receiver, 16 * 1000); + let vad_chunks = audio_source.vad_chunks(redemption_time); + + let chunked = hypr_whisper_local::AudioChunkStream(process_vad_stream(vad_chunks, "mixed")); + + let stream = hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(chunked, model); + process_transcription_stream(ws_sender, stream, guard).await; +} + +async fn handle_dual_channel( + ws_sender: futures_util::stream::SplitSink, + ws_receiver: futures_util::stream::SplitStream, + model: hypr_whisper_local::Whisper, + guard: ConnectionGuard, + redemption_time: Duration, +) { + let (mic_source, speaker_source) = + hypr_ws_utils::split_dual_audio_sources(ws_receiver, 16 * 1000); + + let mic_chunked = { + let mic_vad_chunks = mic_source.vad_chunks(redemption_time); + hypr_whisper_local::AudioChunkStream(process_vad_stream(mic_vad_chunks, "mic")) + }; + + let speaker_chunked = { + let speaker_vad_chunks = speaker_source.vad_chunks(redemption_time); + hypr_whisper_local::AudioChunkStream(process_vad_stream(speaker_vad_chunks, "speaker")) + }; + + let merged_stream = hypr_whisper_local::AudioChunkStream(futures_util::stream::select( + mic_chunked.0, + speaker_chunked.0, + )); + + let stream = + hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(merged_stream, model); + + process_transcription_stream(ws_sender, stream, guard).await; +} + +async fn process_transcription_stream( + mut ws_sender: futures_util::stream::SplitSink, + mut stream: impl futures_util::Stream + Unpin, + guard: ConnectionGuard, +) { + loop { + tokio::select! { + _ = guard.cancelled() => { + tracing::info!("websocket_cancelled_by_new_connection"); + break; + } + chunk_opt = stream.next() => { + let Some(chunk) = chunk_opt else { break }; + + let meta = chunk.meta(); + let text = chunk.text().to_string(); + let start = chunk.start() as u64; + let duration = chunk.duration() as u64; + let confidence = chunk.confidence(); + + let source = meta.and_then(|meta| + meta.get("source") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + ); + let speaker = match source { + Some(s) if s == "mic" => Some(hypr_listener_interface::SpeakerIdentity::Unassigned { index: 0 }), + Some(s) if s == "speaker" => Some(hypr_listener_interface::SpeakerIdentity::Unassigned { index: 1 }), + _ => None, + }; + + let data = ListenOutputChunk { + meta: None, + words: text + .split_whitespace() + .filter(|w| !w.is_empty()) + .map(|w| Word { + text: w.trim().to_string(), + speaker: speaker.clone(), + start_ms: Some(start), + end_ms: Some(start + duration), + confidence: Some(confidence), + }) + .collect(), + }; + + let msg = Message::Text(serde_json::to_string(&data).unwrap().into()); + if let Err(e) = ws_sender.send(msg).await { + tracing::warn!("websocket_send_error: {}", e); + break; + } + } + } + } + + let _ = ws_sender.close().await; +} + +fn process_vad_stream( + stream: S, + source_name: &str, +) -> impl futures_util::Stream +where + S: futures_util::Stream>, + E: std::fmt::Display, +{ + let source_name = source_name.to_string(); + + stream + .take_while(move |chunk_result| { + futures_util::future::ready(match chunk_result { + Ok(_) => true, + Err(e) => { + tracing::error!("vad_error_disconnecting: {}", e); + false + } + }) + }) + .filter_map(move |chunk_result| { + futures_util::future::ready(match chunk_result { + Err(_) => None, + Ok(chunk) => Some(hypr_whisper_local::SimpleAudioChunk { + samples: chunk.samples, + meta: Some(serde_json::json!({ "source": source_name })), + }), + }) + }) +} diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index b836656ccd..406f333ffc 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -13,15 +13,15 @@ harness = false [features] default = [] -coreml = ["hypr-whisper-local/coreml", "hypr-pyannote-local/coreml"] -directml = ["hypr-pyannote-local/directml"] -cuda = ["hypr-whisper-local/cuda"] -hipblas = ["hypr-whisper-local/hipblas"] -openblas = ["hypr-whisper-local/openblas"] -metal = ["hypr-whisper-local/metal"] -vulkan = ["hypr-whisper-local/vulkan"] -openmp = ["hypr-whisper-local/openmp"] -load-dynamic = ["hypr-pyannote-local/load-dynamic"] +coreml = ["hypr-transcribe-whisper-local/coreml"] +directml = ["hypr-transcribe-whisper-local/directml"] +cuda = ["hypr-transcribe-whisper-local/cuda"] +hipblas = ["hypr-transcribe-whisper-local/hipblas"] +openblas = ["hypr-transcribe-whisper-local/openblas"] +metal = ["hypr-transcribe-whisper-local/metal"] +vulkan = ["hypr-transcribe-whisper-local/vulkan"] +openmp = ["hypr-transcribe-whisper-local/openmp"] +load-dynamic = ["hypr-transcribe-whisper-local/load-dynamic"] [build-dependencies] tauri-plugin = { workspace = true, features = ["build"] } @@ -43,13 +43,10 @@ tokio-tungstenite = { workspace = true } [dependencies] hypr-audio-utils = { workspace = true } -hypr-chunker = { workspace = true } hypr-file = { workspace = true } hypr-listener-interface = { workspace = true } -hypr-pyannote-local = { workspace = true } -hypr-whisper = { workspace = true } +hypr-transcribe-whisper-local = { workspace = true } hypr-whisper-local = { workspace = true } -hypr-ws-utils = { workspace = true } tauri = { workspace = true, features = ["test"] } tauri-specta = { workspace = true, features = ["derive", "typescript"] } @@ -65,9 +62,7 @@ specta = { workspace = true } strum = { workspace = true, features = ["derive"] } thiserror = { workspace = true } -rodio = { workspace = true, features = ["symphonia", "symphonia-all"] } - -axum = { workspace = true, features = ["ws", "multipart"] } +axum = { workspace = true, features = ["ws"] } axum-extra = { workspace = true, features = ["query"] } tower-http = { workspace = true, features = ["cors", "trace"] } diff --git a/plugins/local-stt/build.rs b/plugins/local-stt/build.rs index ebfcdb2193..e3cbc1d082 100644 --- a/plugins/local-stt/build.rs +++ b/plugins/local-stt/build.rs @@ -11,7 +11,6 @@ const COMMANDS: &[&str] = &[ "get_current_model", "set_current_model", "list_supported_models", - "process_recorded", ]; fn main() { diff --git a/plugins/local-stt/js/bindings.gen.ts b/plugins/local-stt/js/bindings.gen.ts index b1ca3d9560..58b2abb843 100644 --- a/plugins/local-stt/js/bindings.gen.ts +++ b/plugins/local-stt/js/bindings.gen.ts @@ -42,9 +42,6 @@ async stopServer() : Promise { }, async restartServer() : Promise { return await TAURI_INVOKE("plugin:local-stt|restart_server"); -}, -async processRecorded(audioPath: string) : Promise { - return await TAURI_INVOKE("plugin:local-stt|process_recorded", { audioPath }); } } diff --git a/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml b/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml deleted file mode 100644 index c5b3e018a2..0000000000 --- a/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml +++ /dev/null @@ -1,13 +0,0 @@ -# Automatically generated - DO NOT EDIT! - -"$schema" = "../../schemas/schema.json" - -[[permission]] -identifier = "allow-process-recorded" -description = "Enables the process_recorded command without any pre-configured scope." -commands.allow = ["process_recorded"] - -[[permission]] -identifier = "deny-process-recorded" -description = "Denies the process_recorded command without any pre-configured scope." -commands.deny = ["process_recorded"] diff --git a/plugins/local-stt/permissions/autogenerated/reference.md b/plugins/local-stt/permissions/autogenerated/reference.md index 279af23c28..8c9bc2f0b3 100644 --- a/plugins/local-stt/permissions/autogenerated/reference.md +++ b/plugins/local-stt/permissions/autogenerated/reference.md @@ -15,7 +15,6 @@ Default permissions for the plugin - `allow-get-current-model` - `allow-set-current-model` - `allow-list-supported-models` -- `allow-process-recorded` ## Permission Table @@ -263,32 +262,6 @@ Denies the models_dir command without any pre-configured scope. -`local-stt:allow-process-recorded` - - - - -Enables the process_recorded command without any pre-configured scope. - - - - - - - -`local-stt:deny-process-recorded` - - - - -Denies the process_recorded command without any pre-configured scope. - - - - - - - `local-stt:allow-restart-server` diff --git a/plugins/local-stt/permissions/default.toml b/plugins/local-stt/permissions/default.toml index c3ce8b5a3f..782b142a33 100644 --- a/plugins/local-stt/permissions/default.toml +++ b/plugins/local-stt/permissions/default.toml @@ -12,5 +12,4 @@ permissions = [ "allow-get-current-model", "allow-set-current-model", "allow-list-supported-models", - "allow-process-recorded", ] diff --git a/plugins/local-stt/permissions/schemas/schema.json b/plugins/local-stt/permissions/schemas/schema.json index 69fffdbd17..03817bc0c7 100644 --- a/plugins/local-stt/permissions/schemas/schema.json +++ b/plugins/local-stt/permissions/schemas/schema.json @@ -402,18 +402,6 @@ "const": "deny-models-dir", "markdownDescription": "Denies the models_dir command without any pre-configured scope." }, - { - "description": "Enables the process_recorded command without any pre-configured scope.", - "type": "string", - "const": "allow-process-recorded", - "markdownDescription": "Enables the process_recorded command without any pre-configured scope." - }, - { - "description": "Denies the process_recorded command without any pre-configured scope.", - "type": "string", - "const": "deny-process-recorded", - "markdownDescription": "Denies the process_recorded command without any pre-configured scope." - }, { "description": "Enables the restart_server command without any pre-configured scope.", "type": "string", @@ -463,10 +451,10 @@ "markdownDescription": "Denies the stop_server command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-process-recorded`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-process-recorded`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`" } ] } diff --git a/plugins/local-stt/src/commands.rs b/plugins/local-stt/src/commands.rs index 826379a9a3..209ccd8987 100644 --- a/plugins/local-stt/src/commands.rs +++ b/plugins/local-stt/src/commands.rs @@ -97,24 +97,3 @@ pub async fn restart_server(app: tauri::AppHandle) -> Resu app.stop_server().await.map_err(|e| e.to_string())?; app.start_server().await.map_err(|e| e.to_string()) } - -#[tauri::command] -#[specta::specta] -pub fn process_recorded( - app: tauri::AppHandle, - audio_path: String, -) -> Result<(), String> { - let current_model = app.get_current_model().map_err(|e| e.to_string())?; - let model_path = app.models_dir().join(current_model.file_name()); - - let app_clone = app.clone(); - app.spawn_task_blocking(move |_ctx| { - let app_clone_inner = app_clone.clone(); - let _ = app_clone - .process_recorded(model_path, audio_path, move |event| { - let _ = crate::events::RecordedProcessingEvent::emit(&event, &app_clone_inner); - }) - .map_err(|e| e.to_string()); - }); - Ok(()) -} diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 688e2e53e3..a4121e319e 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -25,13 +25,6 @@ pub trait LocalSttPluginExt { fn get_current_model(&self) -> Result; fn set_current_model(&self, model: crate::SupportedModel) -> Result<(), crate::Error>; - fn process_recorded( - &self, - model_path: impl AsRef, - audio_path: impl AsRef, - progress_fn: impl FnMut(RecordedProcessingEvent) + Send + 'static, - ) -> Result, crate::Error>; - fn download_model( &self, model: crate::SupportedModel, @@ -121,7 +114,7 @@ impl> LocalSttPluginExt for T { return Err(crate::Error::ModelNotDownloaded); } - let server_state = crate::ServerStateBuilder::default() + let server_state = crate::ServerState::builder() .model_cache_dir(cache_dir) .model_type(model) .build(); @@ -194,73 +187,6 @@ impl> LocalSttPluginExt for T { Ok(()) } - #[tracing::instrument(skip_all)] - fn process_recorded( - &self, - model_path: impl AsRef, - audio_path: impl AsRef, - mut progress_fn: impl FnMut(RecordedProcessingEvent) + Send + 'static, - ) -> Result, crate::Error> { - use rodio::Source; - - let decoder = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(audio_path.as_ref()).unwrap(), - )) - .unwrap(); - - let original_sample_rate = decoder.sample_rate(); - - let resampled_samples = if original_sample_rate != 16000 { - hypr_audio_utils::resample_audio(decoder, 16000).unwrap() - } else { - decoder.convert_samples().collect() - }; - - let samples_i16 = hypr_audio_utils::f32_to_i16_samples(&resampled_samples); - - let mut model = hypr_whisper_local::Whisper::builder() - .model_path(model_path.as_ref().to_str().unwrap()) - .languages(vec![]) - .static_prompt("") - .dynamic_prompt("") - .build(); - - let mut segmenter = hypr_pyannote_local::segmentation::Segmenter::new(16000).unwrap(); - let segments = segmenter.process(&samples_i16, 16000).unwrap(); - let num_segments = segments.len(); - - let mut words = Vec::new(); - - for segment in segments { - let audio_f32 = hypr_audio_utils::i16_to_f32_samples(&segment.samples); - - let whisper_segments = model.transcribe(&audio_f32).unwrap(); - - for whisper_segment in whisper_segments { - let start_sec: f64 = segment.start + (whisper_segment.start() as f64); - let end_sec: f64 = segment.start + (whisper_segment.end() as f64); - let start_ms = (start_sec * 1000.0) as u64; - let end_ms = (end_sec * 1000.0) as u64; - - let word = Word { - text: whisper_segment.text().to_string(), - speaker: None, - confidence: Some(whisper_segment.confidence()), - start_ms: Some(start_ms), - end_ms: Some(end_ms), - }; - words.push(word.clone()); - progress_fn(RecordedProcessingEvent::Progress { - current: words.len(), - total: num_segments, - word, - }); - } - } - - Ok(words) - } - #[tracing::instrument(skip_all)] async fn is_model_downloading(&self, model: &crate::SupportedModel) -> bool { let state = self.state::(); diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 493969c21a..fbcab85c73 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -5,12 +5,10 @@ mod commands; mod error; mod events; mod ext; -mod manager; mod model; +mod server; mod store; -pub mod server; - pub use error::*; pub use ext::*; pub use model::*; @@ -44,7 +42,6 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::start_server::, commands::stop_server::, commands::restart_server::, - commands::process_recorded::, ]) .events(tauri_specta::collect_events![ events::RecordedProcessingEvent @@ -155,24 +152,4 @@ mod test { app.stop_server().await.unwrap(); } - - #[tokio::test] - #[ignore] - // cargo test test_local_stt2 -p tauri-plugin-local-stt -- --ignored --nocapture - async fn test_local_stt2() { - let app = create_app(tauri::test::mock_builder()); - - let model_path = dirs::data_dir() - .unwrap() - .join("com.hyprnote.dev/stt") - .join("ggml-tiny.en-q8_0.bin"); - - let words = app - .process_recorded(model_path, hypr_data::english_1::AUDIO_PATH, |event| { - println!("{:?}", event); - }) - .unwrap(); - - println!("{:?}", words); - } } diff --git a/plugins/local-stt/src/manager.rs b/plugins/local-stt/src/manager.rs deleted file mode 100644 index 59ae406386..0000000000 --- a/plugins/local-stt/src/manager.rs +++ /dev/null @@ -1,133 +0,0 @@ -use std::sync::{Arc, Mutex}; -use tokio_util::sync::CancellationToken; - -#[derive(Clone)] -pub struct ConnectionManager { - inner: Arc>>, -} - -impl Default for ConnectionManager { - fn default() -> Self { - Self { - inner: Arc::new(Mutex::new(None)), - } - } -} - -impl ConnectionManager { - pub fn acquire_connection(&self) -> ConnectionGuard { - let mut slot = self.inner.lock().unwrap(); - - if let Some(old) = slot.take() { - old.cancel(); - } - - let token = CancellationToken::new(); - *slot = Some(token.clone()); - - ConnectionGuard { token } - } -} - -pub struct ConnectionGuard { - token: CancellationToken, -} - -impl ConnectionGuard { - pub async fn cancelled(&self) { - self.token.cancelled().await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use axum::{ - extract::{ - ws::{Message, WebSocket}, - State as AxumState, WebSocketUpgrade, - }, - http::StatusCode, - response::IntoResponse, - routing::get, - Router, - }; - use futures_util::{SinkExt, StreamExt}; - use std::{ - future::IntoFuture, - net::{Ipv4Addr, SocketAddr}, - }; - use tokio_tungstenite::{connect_async, tungstenite::protocol::Message as TungsteniteMessage}; - - fn app() -> Router { - let manager = ConnectionManager::default(); - Router::new().route("/ws", get(handler)).with_state(manager) - } - - async fn handler( - ws: WebSocketUpgrade, - AxumState(manager): AxumState, - ) -> Result { - let guard = manager.acquire_connection(); - - Ok(ws.on_upgrade(move |socket| handle_socket(socket, guard))) - } - - async fn handle_socket(socket: WebSocket, _guard: ConnectionGuard) { - let (mut sink, mut stream) = socket.split(); - - while let Some(Ok(msg)) = stream.next().await { - if let Ok(msg) = msg.to_text() { - sink.send(Message::Text(msg.into())).await.unwrap(); - } - } - } - - #[ignore] - #[tokio::test] - async fn integration_test() { - let addr = { - let listener = - tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))) - .await - .unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(axum::serve(listener, app()).into_future()); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - addr - }; - - let socket_1 = { - let (mut socket, _response) = connect_async(format!("ws://{}/ws", addr)).await.unwrap(); - - socket - .send(TungsteniteMessage::Text("test message 1".into())) - .await - .unwrap(); - - let msg = socket.next().await.unwrap().unwrap(); - assert_eq!(msg.to_text().unwrap(), "test message 1"); - - socket - }; - - { - let result = connect_async(format!("ws://{}/ws", addr)).await; - assert!(result.is_err()); - - if let Err(tokio_tungstenite::tungstenite::Error::Http(response)) = result { - assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS); - assert_eq!(response.status().as_u16(), 429); - } else { - panic!("{:?}", result); - } - } - - drop(socket_1); - - { - let result = connect_async(format!("ws://{}/ws", addr)).await; - assert!(result.is_ok()); - } - } -} diff --git a/plugins/local-stt/src/server.rs b/plugins/local-stt/src/server.rs index 7e1453a192..bac82d121d 100644 --- a/plugins/local-stt/src/server.rs +++ b/plugins/local-stt/src/server.rs @@ -1,29 +1,11 @@ use std::{ net::{Ipv4Addr, SocketAddr}, path::PathBuf, - time::Duration, }; -use axum::{ - extract::{ - ws::{Message, WebSocket, WebSocketUpgrade}, - State as AxumState, - }, - http::StatusCode, - response::IntoResponse, - routing::get, - Router, -}; -use axum_extra::extract::Query; - -use futures_util::{SinkExt, StreamExt}; +use axum::{http::StatusCode, response::IntoResponse, routing::get, Router}; use tower_http::cors::{self, CorsLayer}; -use hypr_chunker::VadExt; -use hypr_listener_interface::{ListenOutputChunk, ListenParams, Word}; - -use crate::manager::{ConnectionGuard, ConnectionManager}; - #[derive(Default)] pub struct ServerStateBuilder { pub model_type: Option, @@ -45,7 +27,6 @@ impl ServerStateBuilder { ServerState { model_type: self.model_type.unwrap(), model_cache_dir: self.model_cache_dir.unwrap(), - connection_manager: ConnectionManager::default(), } } } @@ -54,7 +35,12 @@ impl ServerStateBuilder { pub struct ServerState { model_type: crate::SupportedModel, model_cache_dir: PathBuf, - connection_manager: ConnectionManager, +} + +impl ServerState { + pub fn builder() -> ServerStateBuilder { + ServerStateBuilder::default() + } } #[derive(Clone)] @@ -64,16 +50,7 @@ pub struct ServerHandle { } pub async fn run_server(state: ServerState) -> Result { - let router = Router::new() - .route("/health", get(health)) - .route("/api/desktop/listen/realtime", get(listen)) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods(cors::Any) - .allow_headers(cors::Any), - ) - .with_state(state); + let router = make_service_router(state); let listener = tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?; @@ -100,205 +77,54 @@ pub async fn run_server(state: ServerState) -> Result impl IntoResponse { - "ok" -} +fn make_service_router(state: ServerState) -> Router { + let model_path = state.model_cache_dir.join(state.model_type.file_name()); -async fn listen( - Query(params): Query, - ws: WebSocketUpgrade, - AxumState(state): AxumState, -) -> Result { - let guard = state.connection_manager.acquire_connection(); - - Ok(ws.on_upgrade(move |socket| async move { - websocket_with_model(socket, params, state, guard).await - })) -} - -async fn websocket_with_model( - socket: WebSocket, - params: ListenParams, - state: ServerState, - guard: ConnectionGuard, -) { - let model_type = state.model_type; - let model_cache_dir = state.model_cache_dir.clone(); - let model_path = model_cache_dir.join(model_type.file_name()); - - let languages: Vec = params - .languages - .into_iter() - .filter_map(|lang| lang.try_into().ok()) - .collect(); - - let model = hypr_whisper_local::Whisper::builder() - .model_path(model_path.to_str().unwrap()) - .languages(languages) - .static_prompt(¶ms.static_prompt) - .dynamic_prompt(¶ms.dynamic_prompt) + let whisper_service = hypr_transcribe_whisper_local::WhisperStreamingService::builder() + .model_path(model_path) .build(); - let (ws_sender, ws_receiver) = socket.split(); - - match params.audio_mode { - hypr_listener_interface::AudioMode::Single => { - websocket_single_channel( - ws_sender, - ws_receiver, - model, - guard, - Duration::from_millis(params.redemption_time_ms), - ) - .await; - } - hypr_listener_interface::AudioMode::Dual => { - websocket_dual_channel( - ws_sender, - ws_receiver, - model, - guard, - Duration::from_millis(params.redemption_time_ms), - ) - .await; - } - } -} - -async fn websocket_single_channel( - ws_sender: futures_util::stream::SplitSink, - ws_receiver: futures_util::stream::SplitStream, - model: hypr_whisper_local::Whisper, - guard: ConnectionGuard, - redemption_time: Duration, -) { - let audio_source = hypr_ws_utils::WebSocketAudioSource::new(ws_receiver, 16 * 1000); - let vad_chunks = audio_source.vad_chunks(redemption_time); - - let chunked = hypr_whisper_local::AudioChunkStream(process_vad_stream(vad_chunks, "mixed")); - - let stream = hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(chunked, model); - process_transcription_stream(ws_sender, stream, guard).await; + Router::new() + .route("/health", get(health)) + .route_service("/api/desktop/listen/realtime", whisper_service) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(cors::Any) + .allow_headers(cors::Any), + ) } -async fn websocket_dual_channel( - ws_sender: futures_util::stream::SplitSink, - ws_receiver: futures_util::stream::SplitStream, - model: hypr_whisper_local::Whisper, - guard: ConnectionGuard, - redemption_time: Duration, -) { - let (mic_source, speaker_source) = - hypr_ws_utils::split_dual_audio_sources(ws_receiver, 16 * 1000); - - let mic_chunked = { - let mic_vad_chunks = mic_source.vad_chunks(redemption_time); - hypr_whisper_local::AudioChunkStream(process_vad_stream(mic_vad_chunks, "mic")) - }; - - let speaker_chunked = { - let speaker_vad_chunks = speaker_source.vad_chunks(redemption_time); - hypr_whisper_local::AudioChunkStream(process_vad_stream(speaker_vad_chunks, "speaker")) - }; - - let merged_stream = hypr_whisper_local::AudioChunkStream(futures_util::stream::select( - mic_chunked.0, - speaker_chunked.0, - )); - - let stream = - hypr_whisper_local::TranscribeMetadataAudioStreamExt::transcribe(merged_stream, model); - - process_transcription_stream(ws_sender, stream, guard).await; +async fn health() -> impl IntoResponse { + StatusCode::OK } -async fn process_transcription_stream( - mut ws_sender: futures_util::stream::SplitSink, - mut stream: impl futures_util::Stream + Unpin, - guard: ConnectionGuard, -) { - loop { - tokio::select! { - _ = guard.cancelled() => { - tracing::info!("websocket_cancelled_by_new_connection"); - break; - } - chunk_opt = stream.next() => { - let Some(chunk) = chunk_opt else { break }; - - let meta = chunk.meta(); - let text = chunk.text().to_string(); - let start = chunk.start() as u64; - let duration = chunk.duration() as u64; - let confidence = chunk.confidence(); - - - - let source = meta.and_then(|meta| - meta.get("source") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - ); - let speaker = match source { - Some(s) if s == "mic" => Some(hypr_listener_interface::SpeakerIdentity::Unassigned { index: 0 }), - Some(s) if s == "speaker" => Some(hypr_listener_interface::SpeakerIdentity::Unassigned { index: 1 }), - _ => None, - }; - - let data = ListenOutputChunk { - meta: None, - words: text - .split_whitespace() - .filter(|w| !w.is_empty()) - .map(|w| Word { - text: w.trim().to_string(), - speaker: speaker.clone(), - start_ms: Some(start), - end_ms: Some(start + duration), - confidence: Some(confidence), - }) - .collect(), - }; +#[cfg(test)] +mod tests { + use super::*; + use axum::body::Body; + use axum::http::{Request, StatusCode}; + use tower::ServiceExt; + + #[tokio::test] + async fn test_health_endpoint() { + let state = ServerStateBuilder::default() + .model_cache_dir(dirs::data_dir().unwrap().join("com.hyprnote.dev/stt")) + .model_type(crate::SupportedModel::QuantizedTinyEn) + .build(); + + let app = make_service_router(state); + + let response = app + .oneshot( + Request::builder() + .uri("/health") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); - let msg = Message::Text(serde_json::to_string(&data).unwrap().into()); - if let Err(e) = ws_sender.send(msg).await { - tracing::warn!("websocket_send_error: {}", e); - break; - } - } - } + assert_eq!(response.status(), StatusCode::OK); } - - let _ = ws_sender.close().await; -} - -fn process_vad_stream( - stream: S, - source_name: &str, -) -> impl futures_util::Stream -where - S: futures_util::Stream>, - E: std::fmt::Display, -{ - let source_name = source_name.to_string(); - - stream - .take_while(move |chunk_result| { - futures_util::future::ready(match chunk_result { - Ok(_) => true, - Err(e) => { - tracing::error!("vad_error_disconnecting: {}", e); - false // This will end the stream - } - }) - }) - .filter_map(move |chunk_result| { - futures_util::future::ready(match chunk_result { - Err(_) => None, // This shouldn't happen due to take_while above - Ok(chunk) => Some(hypr_whisper_local::SimpleAudioChunk { - samples: chunk.samples, - meta: Some(serde_json::json!({ "source": source_name })), - }), - }) - }) }