From 17e893ac47958d8df2391ad74a4f30fb8968de39 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Mon, 1 Dec 2025 21:30:36 +0900 Subject: [PATCH 1/3] looks good --- crates/ws/src/client.rs | 7 +- owhisper/owhisper-client/Cargo.toml | 1 + .../owhisper-client/src/adapter/argmax.rs | 10 + .../owhisper-client/src/adapter/deepgram.rs | 10 + owhisper/owhisper-client/src/adapter/mod.rs | 9 + owhisper/owhisper-client/src/lib.rs | 44 ++++- owhisper/owhisper-client/src/live.rs | 169 +++++++++++++++-- plugins/listener/src/actors/listener.rs | 176 ++---------------- 8 files changed, 235 insertions(+), 191 deletions(-) create mode 100644 owhisper/owhisper-client/src/adapter/argmax.rs create mode 100644 owhisper/owhisper-client/src/adapter/deepgram.rs create mode 100644 owhisper/owhisper-client/src/adapter/mod.rs diff --git a/crates/ws/src/client.rs b/crates/ws/src/client.rs index 48188236b6..72cc363477 100644 --- a/crates/ws/src/client.rs +++ b/crates/ws/src/client.rs @@ -5,12 +5,9 @@ use futures_util::{ future::{pending, FutureExt}, SinkExt, Stream, StreamExt, }; -use tokio_tungstenite::{ - connect_async, - tungstenite::{client::IntoClientRequest, Utf8Bytes}, -}; +use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest}; -pub use tokio_tungstenite::tungstenite::{protocol::Message, ClientRequestBuilder}; +pub use tokio_tungstenite::tungstenite::{protocol::Message, ClientRequestBuilder, Utf8Bytes}; #[derive(Debug)] enum ControlCommand { diff --git a/owhisper/owhisper-client/Cargo.toml b/owhisper/owhisper-client/Cargo.toml index 2509b01192..4d56a0bae5 100644 --- a/owhisper/owhisper-client/Cargo.toml +++ b/owhisper/owhisper-client/Cargo.toml @@ -14,6 +14,7 @@ owhisper-interface = { workspace = true } futures-util = { workspace = true } reqwest = { workspace = true, features = ["json"] } tokio = { workspace = true } +tokio-stream = { workspace = true } bytes = { workspace = true } serde_json = { workspace = true } diff --git 
a/owhisper/owhisper-client/src/adapter/argmax.rs b/owhisper/owhisper-client/src/adapter/argmax.rs new file mode 100644 index 0000000000..7b801ab674 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/argmax.rs @@ -0,0 +1,10 @@ +use super::SttAdapter; + +#[derive(Clone, Default)] +pub struct ArgmaxAdapter; + +impl SttAdapter for ArgmaxAdapter { + fn supports_native_multichannel(&self) -> bool { + false + } +} diff --git a/owhisper/owhisper-client/src/adapter/deepgram.rs b/owhisper/owhisper-client/src/adapter/deepgram.rs new file mode 100644 index 0000000000..f89f6637ba --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/deepgram.rs @@ -0,0 +1,10 @@ +use super::SttAdapter; + +#[derive(Clone, Default)] +pub struct DeepgramAdapter; + +impl SttAdapter for DeepgramAdapter { + fn supports_native_multichannel(&self) -> bool { + true + } +} diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs new file mode 100644 index 0000000000..4c9ade38c1 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -0,0 +1,9 @@ +mod argmax; +mod deepgram; + +pub use argmax::*; +pub use deepgram::*; + +pub trait SttAdapter: Clone + Default + Send + Sync + 'static { + fn supports_native_multichannel(&self) -> bool; +} diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index cc80f0cb41..8950d5a297 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -1,23 +1,38 @@ +mod adapter; mod batch; mod error; mod live; +use std::marker::PhantomData; + use url::form_urlencoded::Serializer; use url::UrlQuery; +pub use adapter::{ArgmaxAdapter, DeepgramAdapter, SttAdapter}; pub use batch::BatchClient; pub use error::Error; pub use hypr_ws; -pub use live::{ListenClient, ListenClientDual}; +pub use live::{DualHandle, FinalizeHandle, ListenClient, ListenClientDual}; -#[derive(Default)] -pub struct ListenClientBuilder { +pub struct ListenClientBuilder { api_base: 
Option, api_key: Option, params: Option, + _marker: PhantomData, +} + +impl Default for ListenClientBuilder { + fn default() -> Self { + Self { + api_base: None, + api_key: None, + params: None, + _marker: PhantomData, + } + } } -impl ListenClientBuilder { +impl ListenClientBuilder { pub fn api_base(mut self, api_base: impl Into) -> Self { self.api_base = Some(api_base.into()); self @@ -33,6 +48,15 @@ impl ListenClientBuilder { self } + pub fn adapter(self) -> ListenClientBuilder { + ListenClientBuilder { + api_base: self.api_base, + api_key: self.api_key, + params: self.params, + _marker: PhantomData, + } + } + fn listen_endpoint_url(&self) -> url::Url { let mut url: url::Url = self .api_base @@ -171,9 +195,15 @@ impl ListenClientBuilder { self.build_with_channels(1) } - pub fn build_dual(self) -> ListenClientDual { - let request = self.build_request(2); - ListenClientDual { request } + pub fn build_dual(self) -> ListenClientDual { + let adapter = A::default(); + let channels = if adapter.supports_native_multichannel() { + 2 + } else { + 1 + }; + let request = self.build_request(channels); + ListenClientDual { adapter, request } } } diff --git a/owhisper/owhisper-client/src/live.rs b/owhisper/owhisper-client/src/live.rs index c0d8f6e571..7a3a5bdd76 100644 --- a/owhisper/owhisper-client/src/live.rs +++ b/owhisper/owhisper-client/src/live.rs @@ -1,12 +1,15 @@ +use std::pin::Pin; use std::time::Duration; -use futures_util::Stream; +use futures_util::{Stream, StreamExt}; -use hypr_ws::client::{ClientRequestBuilder, Message, WebSocketClient, WebSocketIO}; +use hypr_ws::client::{ + ClientRequestBuilder, Message, Utf8Bytes, WebSocketClient, WebSocketHandle, WebSocketIO, +}; use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use crate::ListenClientBuilder; +use crate::{ListenClientBuilder, SttAdapter}; pub type ListenClientInput = MixedMessage; pub type ListenClientDualInput = MixedMessage<(bytes::Bytes, bytes::Bytes), 
ControlMessage>; @@ -17,10 +20,56 @@ pub struct ListenClient { } #[derive(Clone)] -pub struct ListenClientDual { +pub struct ListenClientDual { + pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, } +pub enum DualHandle { + Native(WebSocketHandle), + Split { + mic: WebSocketHandle, + spk: WebSocketHandle, + }, +} + +pub trait FinalizeHandle: Send { + fn finalize_with_text(&self, text: Utf8Bytes) -> impl std::future::Future + Send; + fn expected_finalize_count(&self) -> usize; +} + +impl FinalizeHandle for WebSocketHandle { + async fn finalize_with_text(&self, text: Utf8Bytes) { + self.finalize_with_text(text).await + } + + fn expected_finalize_count(&self) -> usize { + 1 + } +} + +impl FinalizeHandle for DualHandle { + async fn finalize_with_text(&self, text: Utf8Bytes) { + match self { + DualHandle::Native(h) => h.finalize_with_text(text).await, + DualHandle::Split { mic, spk } => { + let text_clone = text.clone(); + tokio::join!( + WebSocketHandle::finalize_with_text(mic, text), + WebSocketHandle::finalize_with_text(spk, text_clone) + ); + } + } + } + + fn expected_finalize_count(&self) -> usize { + match self { + DualHandle::Native(_) => 1, + DualHandle::Split { .. 
} => 2, + } + } +} + fn interleave_audio(mic: &[u8], speaker: &[u8]) -> Vec { let mic_samples: Vec = mic .chunks_exact(2) @@ -44,7 +93,9 @@ fn interleave_audio(mic: &[u8], speaker: &[u8]) -> Vec { interleaved } -impl WebSocketIO for ListenClient { +pub struct ListenClientIO; + +impl WebSocketIO for ListenClientIO { type Data = ListenClientInput; type Input = ListenClientInput; type Output = StreamResponse; @@ -70,7 +121,9 @@ impl WebSocketIO for ListenClient { } } -impl WebSocketIO for ListenClientDual { +pub struct ListenClientDualIO; + +impl WebSocketIO for ListenClientDualIO { type Data = ListenClientDualInput; type Input = ListenClientInput; type Output = StreamResponse; @@ -118,24 +171,108 @@ impl ListenClient { hypr_ws::Error, > { let ws = websocket_client_with_keep_alive(&self.request); - ws.from_audio::(audio_stream).await + ws.from_audio::(audio_stream).await } } -impl ListenClientDual { +type DualOutputStream = Pin> + Send>>; + +impl ListenClientDual { pub async fn from_realtime_audio( self, stream: impl Stream + Send + Unpin + 'static, - ) -> Result< - ( - impl Stream>, - hypr_ws::client::WebSocketHandle, - ), - hypr_ws::Error, - > { + ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { + if self.adapter.supports_native_multichannel() { + self.from_realtime_audio_native(stream).await + } else { + self.from_realtime_audio_split(stream).await + } + } + + async fn from_realtime_audio_native( + self, + stream: impl Stream + Send + Unpin + 'static, + ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { let ws = websocket_client_with_keep_alive(&self.request); - ws.from_audio::(stream).await + let (output_stream, handle) = ws.from_audio::(stream).await?; + Ok((Box::pin(output_stream), DualHandle::Native(handle))) } + + async fn from_realtime_audio_split( + self, + stream: impl Stream + Send + Unpin + 'static, + ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { + let (mic_tx, mic_rx) = tokio::sync::mpsc::channel::(32); + let 
(spk_tx, spk_rx) = tokio::sync::mpsc::channel::(32); + + let mic_ws = websocket_client_with_keep_alive(&self.request); + let spk_ws = websocket_client_with_keep_alive(&self.request); + + let mic_outbound = tokio_stream::wrappers::ReceiverStream::new(mic_rx); + let spk_outbound = tokio_stream::wrappers::ReceiverStream::new(spk_rx); + + let mic_connect = mic_ws.from_audio::(mic_outbound); + let spk_connect = spk_ws.from_audio::(spk_outbound); + + let ((mic_stream, mic_handle), (spk_stream, spk_handle)) = + tokio::try_join!(mic_connect, spk_connect)?; + + tokio::spawn(forward_dual_to_single(stream, mic_tx, spk_tx)); + + let merged_stream = merge_streams_with_channel_remap(mic_stream, spk_stream); + + Ok(( + Box::pin(merged_stream), + DualHandle::Split { + mic: mic_handle, + spk: spk_handle, + }, + )) + } +} + +async fn forward_dual_to_single( + mut stream: impl Stream + Send + Unpin + 'static, + mic_tx: tokio::sync::mpsc::Sender, + spk_tx: tokio::sync::mpsc::Sender, +) { + while let Some(msg) = stream.next().await { + match msg { + MixedMessage::Audio((mic, spk)) => { + let _ = mic_tx.try_send(MixedMessage::Audio(mic)); + let _ = spk_tx.try_send(MixedMessage::Audio(spk)); + } + MixedMessage::Control(ctrl) => { + let _ = mic_tx.send(MixedMessage::Control(ctrl.clone())).await; + let _ = spk_tx.send(MixedMessage::Control(ctrl)).await; + } + } + } +} + +fn merge_streams_with_channel_remap( + mic_stream: S1, + spk_stream: S2, +) -> impl Stream> + Send +where + S1: Stream> + Send + 'static, + S2: Stream> + Send + 'static, +{ + let mic_mapped = mic_stream.map(|result| { + result.map(|mut response| { + response.set_channel_index(0, 2); + response + }) + }); + + let spk_mapped = spk_stream.map(|result| { + result.map(|mut response| { + response.set_channel_index(1, 2); + response + }) + }); + + futures_util::stream::select(mic_mapped, spk_mapped) } fn websocket_client_with_keep_alive(request: &ClientRequestBuilder) -> WebSocketClient { diff --git 
a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 87d687a938..28d41b9c82 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -4,7 +4,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use futures_util::StreamExt; use tokio::time::error::Elapsed; -use owhisper_client::hypr_ws; +use owhisper_client::{ArgmaxAdapter, DeepgramAdapter, FinalizeHandle, SttAdapter}; use owhisper_interface::stream::{Extra, StreamResponse}; use owhisper_interface::{ControlMessage, MixedMessage}; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; @@ -211,9 +211,9 @@ async fn spawn_rx_task( } crate::actors::ChannelMode::MicAndSpeaker => { if is_local_stt_base_url(&args.base_url) { - spawn_rx_task_dual_split(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself).await } else { - spawn_rx_task_dual(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself).await } } } @@ -297,7 +297,6 @@ async fn spawn_rx_task_single( shutdown_rx, session_offset_secs, extra, - None, ) .await; }); @@ -305,7 +304,7 @@ async fn spawn_rx_task_single( Ok((ChannelSender::Single(tx), rx_task, shutdown_tx)) } -async fn spawn_rx_task_dual( +async fn spawn_rx_task_dual_with_adapter( args: ListenerArgs, myself: ActorRef, ) -> Result< @@ -322,6 +321,7 @@ async fn spawn_rx_task_dual( let (tx, rx) = tokio::sync::mpsc::channel::>(32); let client = owhisper_client::ListenClient::builder() + .adapter::() .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) @@ -356,7 +356,6 @@ async fn spawn_rx_task_dual( shutdown_rx, session_offset_secs, extra, - None, ) .await; }); @@ -364,161 +363,17 @@ async fn spawn_rx_task_dual( Ok((ChannelSender::Dual(tx), rx_task, shutdown_tx)) } -async fn spawn_rx_task_dual_split( - args: ListenerArgs, - myself: ActorRef, -) -> Result< - ( - ChannelSender, - tokio::task::JoinHandle<()>, - 
tokio::sync::oneshot::Sender<()>, - ), - ActorProcessingErr, -> { - let (shutdown_tx_global, shutdown_rx_global) = tokio::sync::oneshot::channel::<()>(); - let (session_offset_secs, extra) = build_extra(&args); - - let (tx, rx) = tokio::sync::mpsc::channel::>(32); - - let (mic_tx, mic_rx) = tokio::sync::mpsc::channel::>(32); - let (spk_tx, spk_rx) = tokio::sync::mpsc::channel::>(32); - - let (shutdown_tx_mic, shutdown_rx_mic) = tokio::sync::oneshot::channel::<()>(); - let (shutdown_tx_spk, shutdown_rx_spk) = tokio::sync::oneshot::channel::<()>(); - - let extra_mic = extra.clone(); - let extra_spk = extra; - - let mic_client = owhisper_client::ListenClient::builder() - .api_base(args.base_url.clone()) - .api_key(args.api_key.clone()) - .params(build_listen_params(&args)) - .build_single(); - - let spk_client = owhisper_client::ListenClient::builder() - .api_base(args.base_url.clone()) - .api_key(args.api_key.clone()) - .params(build_listen_params(&args)) - .build_single(); - - let mic_outbound = tokio_stream::wrappers::ReceiverStream::new(mic_rx); - let spk_outbound = tokio_stream::wrappers::ReceiverStream::new(spk_rx); - - let connect_fut = async { - tokio::try_join!( - mic_client.from_realtime_audio(mic_outbound), - spk_client.from_realtime_audio(spk_outbound) - ) - }; - - let connect_result = tokio::time::timeout(LISTEN_CONNECT_TIMEOUT, connect_fut).await; - - let ((mic_stream, mic_handle), (spk_stream, spk_handle)) = match connect_result { - Err(_elapsed) => { - tracing::error!( - timeout_secs = LISTEN_CONNECT_TIMEOUT.as_secs_f32(), - "listen_ws_connect_timeout(dual_split)" - ); - return Err(actor_error("listen_ws_connect_timeout")); - } - Ok(Err(e)) => { - tracing::error!(error = ?e, "listen_ws_connect_failed(dual_split)"); - return Err(actor_error(format!("listen_ws_connect_failed: {:?}", e))); - } - Ok(Ok(res)) => res, - }; - - let rx_task = tokio::spawn(async move { - let myself_mic = myself.clone(); - let myself_spk = myself; - - let mic_fut = async move { 
- futures_util::pin_mut!(mic_stream); - process_stream( - mic_stream, - mic_handle, - myself_mic, - shutdown_rx_mic, - session_offset_secs, - extra_mic, - Some((0, 2)), - ) - .await; - }; - - let spk_fut = async move { - futures_util::pin_mut!(spk_stream); - process_stream( - spk_stream, - spk_handle, - myself_spk, - shutdown_rx_spk, - session_offset_secs, - extra_spk, - Some((1, 2)), - ) - .await; - }; - - let forward_fut = async move { - let mut rx = rx; - let mut shutdown_rx_global = shutdown_rx_global; - let mut shutdown_tx_mic = Some(shutdown_tx_mic); - let mut shutdown_tx_spk = Some(shutdown_tx_spk); - - loop { - tokio::select! { - _ = &mut shutdown_rx_global => { - if let Some(tx) = shutdown_tx_mic.take() { - let _ = tx.send(()); - } - if let Some(tx) = shutdown_tx_spk.take() { - let _ = tx.send(()); - } - break; - } - msg = rx.recv() => { - match msg { - Some(MixedMessage::Audio((mic, spk))) => { - let _ = mic_tx.try_send(MixedMessage::Audio(mic)); - let _ = spk_tx.try_send(MixedMessage::Audio(spk)); - } - Some(MixedMessage::Control(ctrl)) => { - let _ = mic_tx.send(MixedMessage::Control(ctrl.clone())).await; - let _ = spk_tx.send(MixedMessage::Control(ctrl)).await; - } - None => { - if let Some(tx) = shutdown_tx_mic.take() { - let _ = tx.send(()); - } - if let Some(tx) = shutdown_tx_spk.take() { - let _ = tx.send(()); - } - break; - } - } - } - } - } - }; - - let _ = tokio::join!(mic_fut, spk_fut, forward_fut); - }); - - Ok((ChannelSender::Dual(tx), rx_task, shutdown_tx_global)) -} - -async fn process_stream( +async fn process_stream( mut listen_stream: std::pin::Pin<&mut S>, - handle: hypr_ws::client::WebSocketHandle, + handle: H, myself: ActorRef, mut shutdown_rx: tokio::sync::oneshot::Receiver<()>, offset_secs: f64, extra: Extra, - channel_override: Option<(i32, i32)>, ) where S: futures_util::Stream>, E: std::fmt::Debug, + H: FinalizeHandle, { loop { tokio::select! 
{ @@ -528,7 +383,8 @@ async fn process_stream( let finalize_timeout = tokio::time::sleep(Duration::from_secs(5)); tokio::pin!(finalize_timeout); - let mut received_from_finalize = false; + let expected_count = handle.expected_finalize_count(); + let mut finalize_count = 0usize; loop { tokio::select! { @@ -546,22 +402,19 @@ async fn process_stream( }; if is_from_finalize { - received_from_finalize = true; + finalize_count += 1; } response.apply_offset(offset_secs); response.set_extra(&extra); - if let Some((channel_idx, total_channels)) = channel_override { - response.set_channel_index(channel_idx, total_channels); - } if myself.send_message(ListenerMsg::StreamResponse(response)).is_err() { tracing::warn!("actor_gone_during_finalize"); break; } - if received_from_finalize { - tracing::info!(from_finalize = true, "break_from_finalize"); + if finalize_count >= expected_count { + tracing::info!(finalize_count, expected_count, "break_from_finalize"); break; } } @@ -584,9 +437,6 @@ async fn process_stream( Ok(Some(Ok(mut response))) => { response.apply_offset(offset_secs); response.set_extra(&extra); - if let Some((channel_idx, total_channels)) = channel_override { - response.set_channel_index(channel_idx, total_channels); - } if myself.send_message(ListenerMsg::StreamResponse(response)).is_err() { tracing::warn!("actor_gone_breaking_stream_loop"); From ae224046424f5310ee9e0a61beddca86c7b9f1d1 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Tue, 2 Dec 2025 10:39:02 +0900 Subject: [PATCH 2/3] unrelated --- .../docs/developers/{12.setup.mdx => 0.setup.mdx} | 0 .../developers/{0.analytics.mdx => 12.analytics.mdx} | 0 apps/web/content/docs/developers/13.languages.mdx | 12 ++++++++++++ 3 files changed, 12 insertions(+) rename apps/web/content/docs/developers/{12.setup.mdx => 0.setup.mdx} (100%) rename apps/web/content/docs/developers/{0.analytics.mdx => 12.analytics.mdx} (100%) create mode 100644 apps/web/content/docs/developers/13.languages.mdx diff --git 
a/apps/web/content/docs/developers/12.setup.mdx b/apps/web/content/docs/developers/0.setup.mdx similarity index 100% rename from apps/web/content/docs/developers/12.setup.mdx rename to apps/web/content/docs/developers/0.setup.mdx diff --git a/apps/web/content/docs/developers/0.analytics.mdx b/apps/web/content/docs/developers/12.analytics.mdx similarity index 100% rename from apps/web/content/docs/developers/0.analytics.mdx rename to apps/web/content/docs/developers/12.analytics.mdx diff --git a/apps/web/content/docs/developers/13.languages.mdx b/apps/web/content/docs/developers/13.languages.mdx new file mode 100644 index 0000000000..b99fa9353c --- /dev/null +++ b/apps/web/content/docs/developers/13.languages.mdx @@ -0,0 +1,12 @@ +--- +title: "Language Support" +section: "Developers" +description: "Learn about language support in Hyprnote" +--- + +## Every Provider/Model has its own supported languages + +- Whisper (50-60 languages) +- Deepgram (varies by model; see Deepgram's supported-languages documentation) + +## Every Provider/Model has different ways to specify languages From b980f80884659e69f6e7ae73ce4ddf06b9cd6ccc Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Tue, 2 Dec 2025 10:46:35 +0900 Subject: [PATCH 3/3] looks good --- Cargo.lock | 1 + crates/ws/src/client.rs | 27 +- crates/ws/tests/client_tests.rs | 8 +- owhisper/owhisper-client/Cargo.toml | 3 +- .../owhisper-client/src/adapter/argmax.rs | 87 ++- .../owhisper-client/src/adapter/deepgram.rs | 281 ++++++++- owhisper/owhisper-client/src/adapter/mod.rs | 44 ++ .../owhisper-client/src/adapter/owhisper.rs | 68 +++ .../owhisper-client/src/adapter/soniox.rs | 553 ++++++++++++++++++ owhisper/owhisper-client/src/batch.rs | 136 +---- owhisper/owhisper-client/src/lib.rs | 307 ++-------- owhisper/owhisper-client/src/live.rs | 178 ++++-- plugins/listener/src/actors/listener.rs | 28 +- plugins/listener2/src/ext.rs | 33 +- 14 files changed, 1317 insertions(+), 437 deletions(-) create mode 100644 owhisper/owhisper-client/src/adapter/owhisper.rs create mode 100644 
owhisper/owhisper-client/src/adapter/soniox.rs diff --git a/Cargo.lock b/Cargo.lock index 0e6ee59022..043678c5db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10668,6 +10668,7 @@ dependencies = [ "owhisper-interface", "reqwest 0.12.24", "rodio", + "serde", "serde_json", "thiserror 2.0.17", "tokio", diff --git a/crates/ws/src/client.rs b/crates/ws/src/client.rs index 72cc363477..1a35f8b4cb 100644 --- a/crates/ws/src/client.rs +++ b/crates/ws/src/client.rs @@ -67,6 +67,7 @@ impl WebSocketClient { pub async fn from_audio( &self, + initial_message: Option, mut audio_stream: impl Stream + Send + Unpin + 'static, ) -> Result< ( @@ -96,6 +97,14 @@ impl WebSocketClient { let handle = WebSocketHandle { control_tx }; let _send_task = tokio::spawn(async move { + if let Some(msg) = initial_message { + if let Err(e) = ws_sender.send(msg).await { + tracing::error!("ws_initial_message_failed: {:?}", e); + let _ = error_tx.send(e.into()); + return; + } + } + let mut last_outbound_at = tokio::time::Instant::now(); loop { let mut keep_alive_fut = if let Some(cfg) = keep_alive_config.as_ref() { @@ -128,18 +137,14 @@ impl WebSocketClient { } last_outbound_at = tokio::time::Instant::now(); } - Some(cmd) = control_rx.recv() => { - match cmd { - ControlCommand::Finalize(maybe_msg) => { - if let Some(msg) = maybe_msg { - if let Err(e) = ws_sender.send(msg).await { - tracing::error!("ws_finalize_failed: {:?}", e); - let _ = error_tx.send(e.into()); - break; - } - last_outbound_at = tokio::time::Instant::now(); - } + Some(ControlCommand::Finalize(maybe_msg)) = control_rx.recv() => { + if let Some(msg) = maybe_msg { + if let Err(e) = ws_sender.send(msg).await { + tracing::error!("ws_finalize_failed: {:?}", e); + let _ = error_tx.send(e.into()); + break; } + last_outbound_at = tokio::time::Instant::now(); } } else => break, diff --git a/crates/ws/tests/client_tests.rs b/crates/ws/tests/client_tests.rs index 552de1c8c1..180d66b712 100644 --- a/crates/ws/tests/client_tests.rs +++ 
b/crates/ws/tests/client_tests.rs @@ -98,7 +98,7 @@ async fn test_basic_echo() { ]; let stream = futures_util::stream::iter(messages.clone()); - let (output, _handle) = client.from_audio::(stream).await.unwrap(); + let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 2).await; assert_eq!(received, messages); @@ -115,7 +115,7 @@ async fn test_finalize() { text: "initial".to_string(), count: 1, }]); - let (output, handle) = client.from_audio::(stream).await.unwrap(); + let (output, handle) = client.from_audio::(None, stream).await.unwrap(); let final_msg = TestMessage { text: "final".to_string(), @@ -169,7 +169,7 @@ async fn test_keep_alive() { ); let stream = futures_util::stream::pending::(); - let (output, _handle) = client.from_audio::(stream).await.unwrap(); + let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 1).await; assert_eq!(received[0].text, "done"); @@ -216,7 +216,7 @@ async fn test_retry() { text: "retry_test".to_string(), count: 1, }]); - let (output, _handle) = client.from_audio::(stream).await.unwrap(); + let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 1).await; assert_eq!(received[0].text, "retry_test"); diff --git a/owhisper/owhisper-client/Cargo.toml b/owhisper/owhisper-client/Cargo.toml index 4d56a0bae5..fb9cd8b856 100644 --- a/owhisper/owhisper-client/Cargo.toml +++ b/owhisper/owhisper-client/Cargo.toml @@ -12,11 +12,12 @@ hypr-ws = { workspace = true } owhisper-interface = { workspace = true } futures-util = { workspace = true } -reqwest = { workspace = true, features = ["json"] } +reqwest = { workspace = true, features = ["json", "multipart"] } tokio = { workspace = true } tokio-stream = { workspace = true } bytes = { workspace = true } +serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } tracing = { 
workspace = true } diff --git a/owhisper/owhisper-client/src/adapter/argmax.rs b/owhisper/owhisper-client/src/adapter/argmax.rs index 7b801ab674..c43ec2079d 100644 --- a/owhisper/owhisper-client/src/adapter/argmax.rs +++ b/owhisper/owhisper-client/src/adapter/argmax.rs @@ -1,10 +1,93 @@ -use super::SttAdapter; +use std::path::Path; + +use hypr_ws::client::Message; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::ListenParams; + +use super::{BatchFuture, DeepgramAdapter, SttAdapter}; #[derive(Clone, Default)] -pub struct ArgmaxAdapter; +pub struct ArgmaxAdapter { + inner: DeepgramAdapter, +} impl SttAdapter for ArgmaxAdapter { fn supports_native_multichannel(&self) -> bool { false } + + fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url { + self.inner.build_ws_url(api_base, params, channels) + } + + fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { + self.inner.build_auth_header(api_key) + } + + fn keep_alive_message(&self) -> Option { + self.inner.keep_alive_message() + } + + fn finalize_message(&self) -> Message { + self.inner.finalize_message() + } + + fn parse_response(&self, raw: &str) -> Option { + self.inner.parse_response(raw) + } + + fn transcribe_file<'a, P: AsRef + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a> { + self.inner + .transcribe_file(client, api_base, api_key, params, file_path) + } +} + +#[cfg(test)] +mod tests { + use super::ArgmaxAdapter; + + use futures_util::StreamExt; + use hypr_audio_utils::AudioFormatExt; + + use crate::live::ListenClientInput; + use crate::ListenClientBuilder; + + #[tokio::test] + async fn test_client() { + let audio = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 16000); + + let input = 
Box::pin(tokio_stream::StreamExt::throttle( + audio.map(|chunk| ListenClientInput::Audio(bytes::Bytes::from(chunk.to_vec()))), + std::time::Duration::from_millis(20), + )); + + let client = ListenClientBuilder::default() + .api_base("ws://localhost:50060/v1") + .api_key("".to_string()) + .params(owhisper_interface::ListenParams { + model: Some("large-v3-v20240930_626MB".to_string()), + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .adapter::() + .build_single(); + + let (stream, _) = client.from_realtime_audio(input).await.unwrap(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + println!("{:?}", result); + } + } } diff --git a/owhisper/owhisper-client/src/adapter/deepgram.rs b/owhisper/owhisper-client/src/adapter/deepgram.rs index f89f6637ba..30cfee2dc7 100644 --- a/owhisper/owhisper-client/src/adapter/deepgram.rs +++ b/owhisper/owhisper-client/src/adapter/deepgram.rs @@ -1,10 +1,289 @@ -use super::SttAdapter; +use std::path::{Path, PathBuf}; + +use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source}; +use hypr_ws::client::Message; +use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::ListenParams; + +use super::{BatchFuture, SttAdapter}; +use crate::error::Error; +use crate::{append_keyword_query, append_language_query}; #[derive(Clone, Default)] pub struct DeepgramAdapter; +impl DeepgramAdapter { + fn listen_endpoint_url(api_base: &str) -> url::Url { + let mut url: url::Url = api_base.parse().expect("invalid_api_base"); + + let mut path = url.path().to_string(); + if !path.ends_with('/') { + path.push('/'); + } + path.push_str("listen"); + url.set_path(&path); + + url + } + + fn build_batch_url(api_base: &str, params: &ListenParams) -> url::Url { + let mut url = Self::listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut 
query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "false"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("utterances", "true"); + query_pairs.append_pair("numerals", "true"); + query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("dictation", "false"); + query_pairs.append_pair("paragraphs", "false"); + query_pairs.append_pair("profanity_filter", "false"); + query_pairs.append_pair("measurements", "false"); + query_pairs.append_pair("topics", "false"); + query_pairs.append_pair("sentiment", "false"); + query_pairs.append_pair("intents", "false"); + query_pairs.append_pair("detect_entities", "false"); + query_pairs.append_pair("mip_opt_out", "true"); + + append_keyword_query(&mut query_pairs, params); + } + + url + } + + // https://developers.deepgram.com/reference/speech-to-text/listen-pre-recorded + // https://github.com/deepgram/deepgram-rust-sdk/blob/main/src/listen/rest.rs + async fn do_transcribe_file( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + params: &ListenParams, + file_path: PathBuf, + ) -> Result { + let (audio_data, sample_rate) = decode_audio_to_linear16(file_path).await?; + + let url = { + let mut url = Self::build_batch_url(api_base, params); + url.query_pairs_mut() + .append_pair("sample_rate", &sample_rate.to_string()); + url + }; + + let content_type = format!("audio/raw;encoding=linear16;rate={}", sample_rate); + + let response = client + .post(url) + .header("Authorization", format!("Token {}", api_key)) + .header("Accept", "application/json") + .header("Content-Type", content_type) + 
.body(audio_data) + .send() + .await?; + + let status = response.status(); + if status.is_success() { + Ok(response.json().await?) + } else { + Err(Error::UnexpectedStatus { + status, + body: response.text().await.unwrap_or_default(), + }) + } + } +} + impl SttAdapter for DeepgramAdapter { fn supports_native_multichannel(&self) -> bool { true } + + fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url { + let mut url = Self::listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let channel_string = channels.to_string(); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("channels", &channel_string); + query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("interim_results", "true"); + query_pairs.append_pair("mip_opt_out", "true"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "true"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("vad_events", "false"); + query_pairs.append_pair("numerals", "true"); + + let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); + query_pairs.append_pair("redemption_time_ms", &redemption_time); + + append_keyword_query(&mut query_pairs, params); + } + + if let Some(host) = url.host_str() { + if host.contains("127.0.0.1") || host.contains("localhost") || host.contains("0.0.0.0") + { + let _ = url.set_scheme("ws"); + } else { + let _ = url.set_scheme("wss"); + } + } + + url + } + + fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { + api_key.map(|key| ("Authorization", format!("Token 
{}", key))) + } + + fn keep_alive_message(&self) -> Option { + Some(Message::Text( + serde_json::to_string(&owhisper_interface::ControlMessage::KeepAlive) + .unwrap() + .into(), + )) + } + + fn finalize_message(&self) -> Message { + Message::Text( + serde_json::to_string(&owhisper_interface::ControlMessage::Finalize) + .unwrap() + .into(), + ) + } + + fn parse_response(&self, raw: &str) -> Option { + serde_json::from_str(raw).ok() + } + + fn transcribe_file<'a, P: AsRef + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a> { + let path = file_path.as_ref().to_path_buf(); + Box::pin(Self::do_transcribe_file( + client, api_base, api_key, params, path, + )) + } +} + +async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> { + tokio::task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> { + let decoder = + source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let channels = decoder.channels().max(1); + let sample_rate = decoder.sample_rate(); + + let samples = resample_audio(decoder, sample_rate) + .map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let samples = if channels == 1 { + samples + } else { + let channels_usize = channels as usize; + let mut mono = Vec::with_capacity(samples.len() / channels_usize); + for frame in samples.chunks(channels_usize) { + if frame.is_empty() { + continue; + } + let sum: f32 = frame.iter().copied().sum(); + mono.push(sum / frame.len() as f32); + } + mono + }; + + if samples.is_empty() { + return Err(Error::AudioProcessing( + "audio file contains no samples".to_string(), + )); + } + + let bytes = f32_to_i16_bytes(samples.into_iter()); + + Ok((bytes, sample_rate)) + }) + .await? 
+} + +#[cfg(test)] +mod tests { + use futures_util::StreamExt; + use hypr_audio_utils::AudioFormatExt; + + use crate::live::ListenClientInput; + use crate::ListenClient; + + #[tokio::test] + async fn test_client() { + let _ = tracing_subscriber::fmt::try_init(); + + let audio = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512); + + let input = Box::pin(tokio_stream::StreamExt::throttle( + audio.map(|chunk| ListenClientInput::Audio(chunk)), + std::time::Duration::from_millis(20), + )); + + let client = ListenClient::builder() + .api_base("https://api.deepgram.com/v1") + .api_key("71557216ffdd13bff22702be5017e4852c052b7c") + .params(owhisper_interface::ListenParams { + model: Some("nova-3".to_string()), + languages: vec![ + hypr_language::ISO639::En.into(), + hypr_language::ISO639::Es.into(), + ], + ..Default::default() + }) + .build_single(); + + let (stream, _) = client.from_realtime_audio(input).await.unwrap(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + match result { + Ok(response) => match response { + owhisper_interface::stream::StreamResponse::TranscriptResponse { + channel, + .. 
+ } => { + println!("{:?}", channel.alternatives.first().unwrap().transcript); + } + _ => {} + }, + _ => {} + } + } + } } diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs index 4c9ade38c1..339e513829 100644 --- a/owhisper/owhisper-client/src/adapter/mod.rs +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -1,9 +1,53 @@ mod argmax; mod deepgram; +mod owhisper; +mod soniox; pub use argmax::*; pub use deepgram::*; +pub use soniox::*; + +use std::future::Future; +use std::path::Path; +use std::pin::Pin; + +use hypr_ws::client::Message; +use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::ListenParams; + +use crate::error::Error; + +pub type BatchFuture<'a> = Pin> + Send + 'a>>; pub trait SttAdapter: Clone + Default + Send + Sync + 'static { fn supports_native_multichannel(&self) -> bool; + + fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url; + + fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)>; + + fn keep_alive_message(&self) -> Option; + + fn finalize_message(&self) -> Message; + + fn initial_message( + &self, + _api_key: Option<&str>, + _params: &ListenParams, + _channels: u8, + ) -> Option { + None + } + + fn parse_response(&self, raw: &str) -> Option; + + fn transcribe_file<'a, P: AsRef + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a>; } diff --git a/owhisper/owhisper-client/src/adapter/owhisper.rs b/owhisper/owhisper-client/src/adapter/owhisper.rs new file mode 100644 index 0000000000..2edf785f33 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/owhisper.rs @@ -0,0 +1,68 @@ +#[cfg(test)] +mod tests { + use futures_util::StreamExt; + use hypr_audio_utils::AudioFormatExt; + + use crate::live::ListenClientInput; + use crate::ListenClient; 
+ + #[tokio::test] + async fn test_owhisper_with_owhisper() { + let audio = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512); + let input = audio.map(|chunk| ListenClientInput::Audio(chunk)); + + let client = ListenClient::builder() + .api_base("ws://127.0.0.1:52693/v1") + .api_key("".to_string()) + .params(owhisper_interface::ListenParams { + model: Some("whisper-cpp-small-q8".to_string()), + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .build_single(); + + let (stream, _) = client.from_realtime_audio(input).await.unwrap(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + println!("{:?}", result); + } + } + + #[tokio::test] + async fn test_owhisper_with_deepgram() { + let audio = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512) + .map(Ok::<_, std::io::Error>); + + let mut stream = + deepgram::Deepgram::with_base_url_and_api_key("ws://127.0.0.1:52978", "TODO") + .unwrap() + .transcription() + .stream_request_with_options( + deepgram::common::options::Options::builder() + .language(deepgram::common::options::Language::en) + .model(deepgram::common::options::Model::CustomId( + "whisper-cpp-small-q8".to_string(), + )) + .build(), + ) + .channels(1) + .encoding(deepgram::common::options::Encoding::Linear16) + .sample_rate(16000) + .stream(audio) + .await + .unwrap(); + + while let Some(result) = stream.next().await { + println!("{:?}", result); + } + } +} diff --git a/owhisper/owhisper-client/src/adapter/soniox.rs b/owhisper/owhisper-client/src/adapter/soniox.rs new file mode 100644 index 0000000000..f90a6aaad4 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/soniox.rs @@ -0,0 +1,553 @@ +use std::path::Path; +use std::time::Duration; + +use hypr_ws::client::Message; +use 
owhisper_interface::batch::{ + Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse, + Results as BatchResults, Word as BatchWord, +}; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; +use owhisper_interface::ListenParams; +use serde::{Deserialize, Serialize}; + +use super::{BatchFuture, SttAdapter}; +use crate::error::Error; + +const DEFAULT_API_BASE: &str = "https://api.soniox.com"; +const POLL_INTERVAL: Duration = Duration::from_secs(2); +const MAX_POLL_ATTEMPTS: u32 = 300; + +#[derive(Clone, Default)] +pub struct SonioxAdapter; + +impl SonioxAdapter { + fn language_hints(params: &ListenParams) -> Vec { + params + .languages + .iter() + .map(|lang| lang.iso639().code().to_string()) + .collect() + } + + fn api_base_url(api_base: &str) -> String { + if api_base.is_empty() { + DEFAULT_API_BASE.to_string() + } else { + api_base.trim_end_matches('/').to_string() + } + } + + async fn upload_file( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + file_path: &Path, + ) -> Result { + let file_name = file_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("audio.wav") + .to_string(); + + let file_bytes = tokio::fs::read(file_path).await.map_err(|e| { + Error::AudioProcessing(format!( + "failed to read file {}: {}", + file_path.display(), + e + )) + })?; + + let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); + let form = reqwest::multipart::Form::new().part("file", part); + + let url = format!("{}/v1/files", Self::api_base_url(api_base)); + let response = client + .post(&url) + .header("Authorization", format!("Bearer {}", api_key)) + .multipart(form) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { status, body }); + } + + #[derive(Deserialize)] + struct FileUploadResponse { + id: String, + } + + let upload_response: 
FileUploadResponse = response.json().await?; + Ok(upload_response.id) + } + + async fn create_transcription( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + params: &ListenParams, + file_id: &str, + ) -> Result { + #[derive(Serialize)] + struct CreateTranscriptionRequest<'a> { + model: &'a str, + file_id: &'a str, + #[serde(skip_serializing_if = "Vec::is_empty")] + language_hints: Vec, + enable_speaker_diarization: bool, + enable_language_identification: bool, + } + + let model = params.model.as_deref().unwrap_or("stt-async-preview"); + + let request = CreateTranscriptionRequest { + model, + file_id, + language_hints: Self::language_hints(params), + enable_speaker_diarization: true, + enable_language_identification: true, + }; + + let url = format!("{}/v1/transcriptions", Self::api_base_url(api_base)); + let response = client + .post(&url) + .header("Authorization", format!("Bearer {}", api_key)) + .json(&request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { status, body }); + } + + #[derive(Deserialize)] + struct TranscriptionResponse { + id: String, + } + + let transcription: TranscriptionResponse = response.json().await?; + Ok(transcription.id) + } + + async fn poll_transcription( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + transcription_id: &str, + ) -> Result<(), Error> { + #[derive(Deserialize)] + struct TranscriptionResponse { + status: String, + #[serde(default)] + error_message: Option, + } + + let url = format!( + "{}/v1/transcriptions/{}", + Self::api_base_url(api_base), + transcription_id + ); + + for attempt in 0..MAX_POLL_ATTEMPTS { + let response = client + .get(&url) + .header("Authorization", format!("Bearer {}", api_key)) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + return 
Err(Error::UnexpectedStatus { status, body }); + } + + let transcription: TranscriptionResponse = response.json().await?; + + match transcription.status.as_str() { + "completed" => return Ok(()), + "error" => { + let error_msg = transcription + .error_message + .unwrap_or_else(|| "unknown error".to_string()); + return Err(Error::AudioProcessing(format!( + "transcription failed: {}", + error_msg + ))); + } + "queued" | "processing" => { + tracing::debug!( + attempt = attempt, + status = transcription.status, + "polling transcription status" + ); + tokio::time::sleep(POLL_INTERVAL).await; + } + unknown => { + return Err(Error::AudioProcessing(format!( + "unexpected transcription status: {}", + unknown + ))); + } + } + } + + Err(Error::AudioProcessing(format!( + "transcription timed out after {} attempts", + MAX_POLL_ATTEMPTS + ))) + } + + async fn get_transcript( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + transcription_id: &str, + ) -> Result { + #[derive(Deserialize)] + struct TranscriptResponse { + text: String, + tokens: Vec, + } + + #[derive(Deserialize)] + struct TranscriptToken { + text: String, + #[serde(default)] + start_ms: Option, + #[serde(default)] + end_ms: Option, + #[serde(default)] + confidence: Option, + #[serde(default)] + speaker: Option, + } + + #[derive(Deserialize)] + #[serde(untagged)] + enum BatchSpeakerId { + Num(i32), + Str(String), + } + + impl BatchSpeakerId { + fn as_usize(&self) -> Option { + match self { + BatchSpeakerId::Num(n) if *n >= 0 => Some(*n as usize), + BatchSpeakerId::Num(_) => None, + BatchSpeakerId::Str(s) => s + .trim_start_matches(|c: char| !c.is_ascii_digit()) + .parse() + .ok(), + } + } + } + + let url = format!( + "{}/v1/transcriptions/{}/transcript", + Self::api_base_url(api_base), + transcription_id + ); + + let response = client + .get(&url) + .header("Authorization", format!("Bearer {}", api_key)) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = 
response.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { status, body }); + } + + let transcript: TranscriptResponse = response.json().await?; + + let words: Vec = transcript + .tokens + .iter() + .map(|token| BatchWord { + word: token.text.clone(), + start: token.start_ms.unwrap_or(0) as f64 / 1000.0, + end: token.end_ms.unwrap_or(0) as f64 / 1000.0, + confidence: token.confidence.unwrap_or(1.0), + speaker: token.speaker.as_ref().and_then(|s| s.as_usize()), + punctuated_word: Some(token.text.clone()), + }) + .collect(); + + let alternatives = BatchAlternatives { + transcript: transcript.text, + confidence: 1.0, + words, + }; + + let channel = BatchChannel { + alternatives: vec![alternatives], + }; + + Ok(BatchResponse { + metadata: serde_json::json!({}), + results: BatchResults { + channels: vec![channel], + }, + }) + } + + async fn do_transcribe_file( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + params: &ListenParams, + file_path: &Path, + ) -> Result { + tracing::info!(path = %file_path.display(), "uploading file to Soniox"); + + let file_id = Self::upload_file(client, api_base, api_key, file_path).await?; + tracing::info!(file_id = %file_id, "file uploaded, creating transcription"); + + let transcription_id = + Self::create_transcription(client, api_base, api_key, params, &file_id).await?; + tracing::info!(transcription_id = %transcription_id, "transcription created, polling for completion"); + + Self::poll_transcription(client, api_base, api_key, &transcription_id).await?; + tracing::info!(transcription_id = %transcription_id, "transcription completed, fetching transcript"); + + let response = Self::get_transcript(client, api_base, api_key, &transcription_id).await?; + tracing::info!("transcript fetched successfully"); + + Ok(response) + } +} + +impl SttAdapter for SonioxAdapter { + fn supports_native_multichannel(&self) -> bool { + true + } + + fn build_ws_url(&self, api_base: &str, _params: &ListenParams, 
_channels: u8) -> url::Url { + let mut url: url::Url = api_base.parse().expect("invalid api_base"); + + match url.scheme() { + "http" => { + let _ = url.set_scheme("ws"); + } + "https" => { + let _ = url.set_scheme("wss"); + } + "ws" | "wss" => {} + _ => { + let _ = url.set_scheme("wss"); + } + } + + url + } + + fn build_auth_header(&self, _api_key: Option<&str>) -> Option<(&'static str, String)> { + None + } + + fn keep_alive_message(&self) -> Option { + Some(Message::Text(r#"{"type":"keepalive"}"#.into())) + } + + fn initial_message( + &self, + api_key: Option<&str>, + params: &ListenParams, + channels: u8, + ) -> Option { + let api_key = match api_key { + Some(key) => key, + None => { + tracing::warn!("soniox_api_key_missing"); + return None; + } + }; + + #[derive(Serialize)] + struct SonioxConfig<'a> { + api_key: &'a str, + model: &'a str, + audio_format: &'a str, + num_channels: u8, + sample_rate: u32, + #[serde(skip_serializing_if = "Vec::is_empty")] + language_hints: Vec, + include_nonfinal: bool, + enable_endpoint_detection: bool, + enable_speaker_diarization: bool, + } + + let model = params.model.as_deref().unwrap_or("stt-rt-preview"); + + let cfg = SonioxConfig { + api_key, + model, + audio_format: "pcm_s16le", + num_channels: channels, + sample_rate: params.sample_rate, + language_hints: Self::language_hints(params), + include_nonfinal: true, + enable_endpoint_detection: true, + enable_speaker_diarization: true, + }; + + let json = serde_json::to_string(&cfg).unwrap(); + Some(Message::Text(json.into())) + } + + fn parse_response(&self, raw: &str) -> Option { + #[derive(Deserialize)] + struct Token { + text: String, + #[serde(default)] + start_ms: Option, + #[serde(default)] + end_ms: Option, + #[serde(default)] + confidence: Option, + #[serde(default)] + is_final: Option, + #[serde(default)] + speaker: Option, + #[serde(default)] + channel: Option, + } + + #[derive(Deserialize)] + #[serde(untagged)] + enum SpeakerId { + Num(i32), + Str(String), + } + + 
impl SpeakerId { + fn as_i32(&self) -> Option { + match self { + SpeakerId::Num(n) => Some(*n), + SpeakerId::Str(s) => s + .trim_start_matches(|c: char| !c.is_ascii_digit()) + .parse() + .ok(), + } + } + } + + #[derive(Deserialize)] + struct SonioxMessage { + #[serde(default)] + tokens: Vec, + #[serde(default)] + finished: Option, + #[serde(default)] + error: Option, + } + + let msg: SonioxMessage = match serde_json::from_str(raw) { + Ok(m) => m, + Err(e) => { + tracing::warn!(error = ?e, raw = raw, "soniox_json_parse_failed"); + return None; + } + }; + + if let Some(error) = msg.error { + tracing::error!(error = error, "soniox_error"); + return None; + } + + let has_fin_token = msg.tokens.iter().any(|t| t.text == ""); + let is_finished = msg.finished.unwrap_or(false) || has_fin_token; + + let content_tokens: Vec<_> = msg + .tokens + .iter() + .filter(|t| t.text != "" && t.text != "") + .collect(); + + if content_tokens.is_empty() && !is_finished { + return None; + } + + let all_final = content_tokens.iter().all(|t| t.is_final.unwrap_or(true)); + + let mut words = Vec::with_capacity(content_tokens.len()); + let mut transcript = String::new(); + + let channel_index = content_tokens.first().and_then(|t| t.channel).unwrap_or(0) as i32; + + for t in &content_tokens { + if !transcript.is_empty() && !t.text.starts_with(|c: char| c.is_ascii_punctuation()) { + transcript.push(' '); + } + transcript.push_str(&t.text); + + let start_secs = t.start_ms.unwrap_or(0) as f64 / 1000.0; + let end_secs = t.end_ms.unwrap_or(0) as f64 / 1000.0; + let speaker = t.speaker.as_ref().and_then(|s| s.as_i32()); + + words.push(Word { + word: t.text.clone(), + start: start_secs, + end: end_secs, + confidence: t.confidence.unwrap_or(1.0), + speaker, + punctuated_word: Some(t.text.clone()), + language: None, + }); + } + + let (start, duration) = + if let (Some(first), Some(last)) = (content_tokens.first(), content_tokens.last()) { + let start_secs = first.start_ms.unwrap_or(0) as f64 / 1000.0; + 
let end_secs = last.end_ms.unwrap_or(0) as f64 / 1000.0; + (start_secs, end_secs - start_secs) + } else { + (0.0, 0.0) + }; + + let channel = Channel { + alternatives: vec![Alternatives { + transcript, + words, + confidence: 1.0, + languages: vec![], + }], + }; + + Some(StreamResponse::TranscriptResponse { + is_final: all_final || is_finished, + speech_final: is_finished, + from_finalize: has_fin_token, + start, + duration, + channel, + metadata: Metadata::default(), + // TODO + channel_index: vec![channel_index, 1], + }) + } + + fn finalize_message(&self) -> Message { + Message::Text(r#"{"type":"finalize"}"#.into()) + } + + fn transcribe_file<'a, P: AsRef + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a> { + let path = file_path.as_ref().to_path_buf(); + Box::pin( + async move { Self::do_transcribe_file(client, api_base, api_key, params, &path).await }, + ) + } +} diff --git a/owhisper/owhisper-client/src/batch.rs b/owhisper/owhisper-client/src/batch.rs index b9c7dca03b..08f75545cd 100644 --- a/owhisper/owhisper-client/src/batch.rs +++ b/owhisper/owhisper-client/src/batch.rs @@ -1,118 +1,46 @@ -use std::path::{Path, PathBuf}; -use tokio::task; +use std::marker::PhantomData; +use std::path::Path; -use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source}; use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::ListenParams; -use crate::{error::Error, ListenClientBuilder}; +use crate::adapter::SttAdapter; +use crate::error::Error; +use crate::DeepgramAdapter; -// https://developers.deepgram.com/reference/speech-to-text/listen-pre-recorded -// https://github.com/deepgram/deepgram-rust-sdk/blob/main/src/listen/rest.rs #[derive(Clone)] -pub struct BatchClient { - pub(crate) client: reqwest::Client, - pub(crate) url: url::Url, - pub(crate) api_key: Option, +pub struct BatchClient { + client: 
reqwest::Client, + api_base: String, + api_key: String, + params: ListenParams, + _marker: PhantomData, } -impl BatchClient { - pub fn builder() -> ListenClientBuilder { - ListenClientBuilder::default() +impl BatchClient { + pub fn new(api_base: String, api_key: String, params: ListenParams) -> Self { + Self { + client: reqwest::Client::new(), + api_base, + api_key, + params, + _marker: PhantomData, + } } - pub async fn transcribe_file>( + pub async fn transcribe_file + Send>( &self, file_path: P, ) -> Result { - let path = file_path.as_ref(); - let (audio_data, sample_rate) = decode_audio_to_linear16(path.to_path_buf()).await?; - - let params = { - let mut params: Vec<(String, String)> = vec![]; - params.retain(|(key, _)| key != "channels"); - - params.push(("sample_rate".to_string(), sample_rate.to_string())); - params.push(("multichannel".to_string(), "false".to_string())); - params.push(("diarize".to_string(), "true".to_string())); - params.push(("detect_language".to_string(), "true".to_string())); - params - }; - - let url = { - let mut url = self.url.clone(); - - let mut serializer = url::form_urlencoded::Serializer::new(String::new()); - for (key, value) in params { - serializer.append_pair(&key, &value); - } - - let query = serializer.finish(); - url.set_query(Some(&query)); - url - }; - - let mut request = self.client.post(url); - - if let Some(key) = &self.api_key { - request = request.header("Authorization", format!("Token {}", key)); - } - - let content_type = format!("audio/raw;encoding=linear16;rate={}", sample_rate); - - let response = request - .header("Accept", "application/json") - .header("Content-Type", content_type) - .body(audio_data) - .send() - .await?; - - let status = response.status(); - if status.is_success() { - Ok(response.json().await?) 
- } else { - Err(Error::UnexpectedStatus { - status, - body: response.text().await.unwrap_or_default(), - }) - } + let adapter = A::default(); + adapter + .transcribe_file( + &self.client, + &self.api_base, + &self.api_key, + &self.params, + file_path, + ) + .await } } - -async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> { - task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> { - let decoder = - source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; - - let channels = decoder.channels().max(1); - let sample_rate = decoder.sample_rate(); - - let samples = resample_audio(decoder, sample_rate) - .map_err(|err| Error::AudioProcessing(err.to_string()))?; - - let samples = if channels == 1 { - samples - } else { - let channels_usize = channels as usize; - let mut mono = Vec::with_capacity(samples.len() / channels_usize); - for frame in samples.chunks(channels_usize) { - if frame.is_empty() { - continue; - } - let sum: f32 = frame.iter().copied().sum(); - mono.push(sum / frame.len() as f32); - } - mono - }; - - if samples.is_empty() { - return Err(Error::AudioProcessing( - "audio file contains no samples".to_string(), - )); - } - - let bytes = f32_to_i16_bytes(samples.into_iter()); - - Ok((bytes, sample_rate)) - }) - .await? 
-} diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index 8950d5a297..5b6012730c 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -8,7 +8,7 @@ use std::marker::PhantomData; use url::form_urlencoded::Serializer; use url::UrlQuery; -pub use adapter::{ArgmaxAdapter, DeepgramAdapter, SttAdapter}; +pub use adapter::{ArgmaxAdapter, DeepgramAdapter, SonioxAdapter, SttAdapter}; pub use batch::BatchClient; pub use error::Error; pub use hypr_ws; @@ -57,141 +57,50 @@ impl ListenClientBuilder { } } - fn listen_endpoint_url(&self) -> url::Url { - let mut url: url::Url = self - .api_base - .as_ref() - .expect("api_base is required") - .parse() - .expect("invalid api_base"); - - let mut path = url.path().to_string(); - if !path.ends_with('/') { - path.push('/'); - } - path.push_str("listen"); - url.set_path(&path); - - url + fn get_api_base(&self) -> &str { + self.api_base.as_ref().expect("api_base is required") } - pub(crate) fn build_batch_url(&self) -> url::Url { - let params = self.params.clone().unwrap_or_default(); - let mut url = self.listen_endpoint_url(); - - { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, ¶ms); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("sample_rate", &sample_rate); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "false"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("utterances", "true"); - query_pairs.append_pair("numerals", "true"); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("dictation", "false"); - query_pairs.append_pair("paragraphs", "false"); - 
query_pairs.append_pair("profanity_filter", "false"); - query_pairs.append_pair("measurements", "false"); - query_pairs.append_pair("topics", "false"); - query_pairs.append_pair("sentiment", "false"); - query_pairs.append_pair("intents", "false"); - query_pairs.append_pair("detect_entities", "false"); - query_pairs.append_pair("mip_opt_out", "true"); - - append_keyword_query(&mut query_pairs, ¶ms); - } - - url + fn get_params(&self) -> owhisper_interface::ListenParams { + self.params.clone().unwrap_or_default() } - pub(crate) fn build_url(&self, channels: u8) -> url::Url { - let mut params = self.params.clone().unwrap_or_default(); - params.channels = channels; + fn build_request(&self, adapter: &A, channels: u8) -> hypr_ws::client::ClientRequestBuilder { + let params = self.get_params(); + let url = adapter.build_ws_url(self.get_api_base(), ¶ms, channels); + let uri = url.to_string().parse().unwrap(); - let mut url = self.listen_endpoint_url(); + let mut request = hypr_ws::client::ClientRequestBuilder::new(uri); + if let Some((header_name, header_value)) = + adapter.build_auth_header(self.api_key.as_deref()) { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, ¶ms); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let channel_string = channels.to_string(); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("channels", &channel_string); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("interim_results", "true"); - query_pairs.append_pair("mip_opt_out", "true"); - query_pairs.append_pair("sample_rate", &sample_rate); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "true"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("vad_events", "false"); 
- query_pairs.append_pair("numerals", "true"); - - let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); - query_pairs.append_pair("redemption_time_ms", &redemption_time); - - append_keyword_query(&mut query_pairs, ¶ms); - } - - url - } - - pub(crate) fn build_uri(&self, channels: u8) -> String { - let mut url = self.build_url(channels); - - if let Some(host) = url.host_str() { - if host.contains("127.0.0.1") || host.contains("localhost") || host.contains("0.0.0.0") - { - let _ = url.set_scheme("ws"); - } else { - let _ = url.set_scheme("wss"); - } + request = request.with_header(header_name, header_value); } - url.to_string() + request } - pub(crate) fn build_request(&self, channels: u8) -> hypr_ws::client::ClientRequestBuilder { - let uri = self.build_uri(channels).parse().unwrap(); - - match &self.api_key { - Some(key) => hypr_ws::client::ClientRequestBuilder::new(uri) - .with_header("Authorization", format!("Token {}", key)), - None => hypr_ws::client::ClientRequestBuilder::new(uri), + pub fn build_with_channels(self, channels: u8) -> ListenClient { + let adapter = A::default(); + let params = self.get_params(); + let request = self.build_request(&adapter, channels); + let initial_message = adapter.initial_message(self.api_key.as_deref(), ¶ms, channels); + + ListenClient { + adapter, + request, + initial_message, } } - pub fn build_with_channels(self, channels: u8) -> ListenClient { - let request = self.build_request(channels); - ListenClient { request } + pub fn build_batch(self) -> BatchClient { + let params = self.get_params(); + let api_base = self.get_api_base().to_string(); + BatchClient::new(api_base, self.api_key.unwrap_or_default(), params) } - pub fn build_batch(self) -> BatchClient { - let url = self.build_batch_url(); - - BatchClient { - client: reqwest::Client::new(), - url, - api_key: self.api_key, - } - } - - pub fn build_single(self) -> ListenClient { + pub fn build_single(self) -> ListenClient { self.build_with_channels(1) 
} @@ -202,8 +111,15 @@ impl ListenClientBuilder { } else { 1 }; - let request = self.build_request(channels); - ListenClientDual { adapter, request } + let params = self.get_params(); + let request = self.build_request(&adapter, channels); + let initial_message = adapter.initial_message(self.api_key.as_deref(), ¶ms, channels); + + ListenClientDual { + adapter, + request, + initial_message, + } } } @@ -252,150 +168,3 @@ pub(crate) fn append_keyword_query<'a>( query_pairs.append_pair(param_name, keyword); } } - -#[cfg(test)] -mod tests { - use super::*; - - use futures_util::StreamExt; - use hypr_audio_utils::AudioFormatExt; - use live::ListenClientInput; - - #[tokio::test] - async fn test_client_deepgram() { - let _ = tracing_subscriber::fmt::try_init(); - - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - .to_i16_le_chunks(16000, 512); - - let input = Box::pin(tokio_stream::StreamExt::throttle( - audio.map(|chunk| ListenClientInput::Audio(chunk)), - std::time::Duration::from_millis(20), - )); - - let client = ListenClient::builder() - .api_base("https://api.deepgram.com/v1") - .api_key(std::env::var("DEEPGRAM_API_KEY").unwrap()) - .params(owhisper_interface::ListenParams { - model: Some("nova-3".to_string()), - languages: vec![ - hypr_language::ISO639::En.into(), - hypr_language::ISO639::Es.into(), - ], - ..Default::default() - }) - .build_single(); - - let (stream, _) = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); - - while let Some(result) = stream.next().await { - match result { - Ok(response) => match response { - owhisper_interface::stream::StreamResponse::TranscriptResponse { - channel, - .. 
- } => { - println!("{:?}", channel.alternatives.first().unwrap().transcript); - } - _ => {} - }, - _ => {} - } - } - } - - #[tokio::test] - async fn test_owhisper_with_owhisper() { - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - .to_i16_le_chunks(16000, 512); - let input = audio.map(|chunk| ListenClientInput::Audio(chunk)); - - let client = ListenClient::builder() - .api_base("ws://127.0.0.1:52693/v1") - .api_key("".to_string()) - .params(owhisper_interface::ListenParams { - model: Some("whisper-cpp-small-q8".to_string()), - languages: vec![hypr_language::ISO639::En.into()], - ..Default::default() - }) - .build_single(); - - let (stream, _) = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); - - while let Some(result) = stream.next().await { - println!("{:?}", result); - } - } - - #[tokio::test] - async fn test_owhisper_with_deepgram() { - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - .to_i16_le_chunks(16000, 512) - .map(Ok::<_, std::io::Error>); - - let mut stream = - deepgram::Deepgram::with_base_url_and_api_key("ws://127.0.0.1:52978", "TODO") - .unwrap() - .transcription() - .stream_request_with_options( - deepgram::common::options::Options::builder() - .language(deepgram::common::options::Language::en) - .model(deepgram::common::options::Model::CustomId( - "whisper-cpp-small-q8".to_string(), - )) - .build(), - ) - .channels(1) - .encoding(deepgram::common::options::Encoding::Linear16) - .sample_rate(16000) - .stream(audio) - .await - .unwrap(); - - while let Some(result) = stream.next().await { - println!("{:?}", result); - } - } - - #[tokio::test] - async fn test_client_ag() { - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - 
.to_i16_le_chunks(16000, 16000); - - let input = Box::pin(tokio_stream::StreamExt::throttle( - audio.map(|chunk| ListenClientInput::Audio(bytes::Bytes::from(chunk.to_vec()))), - std::time::Duration::from_millis(20), - )); - - let client = ListenClient::builder() - .api_base("ws://localhost:50060/v1") - .api_key("".to_string()) - .params(owhisper_interface::ListenParams { - model: Some("large-v3-v20240930_626MB".to_string()), - languages: vec![hypr_language::ISO639::En.into()], - ..Default::default() - }) - .build_single(); - - let (stream, _) = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); - - while let Some(result) = stream.next().await { - println!("{:?}", result); - } - } -} diff --git a/owhisper/owhisper-client/src/live.rs b/owhisper/owhisper-client/src/live.rs index 7a3a5bdd76..43e3b3b1b3 100644 --- a/owhisper/owhisper-client/src/live.rs +++ b/owhisper/owhisper-client/src/live.rs @@ -9,38 +9,52 @@ use hypr_ws::client::{ use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use crate::{ListenClientBuilder, SttAdapter}; +use crate::{DeepgramAdapter, ListenClientBuilder, SttAdapter}; pub type ListenClientInput = MixedMessage; pub type ListenClientDualInput = MixedMessage<(bytes::Bytes, bytes::Bytes), ControlMessage>; #[derive(Clone)] -pub struct ListenClient { +pub struct ListenClient { + pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, + pub(crate) initial_message: Option, } #[derive(Clone)] pub struct ListenClientDual { pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, + pub(crate) initial_message: Option, +} + +pub struct SingleHandle { + inner: WebSocketHandle, + finalize_text: Utf8Bytes, } pub enum DualHandle { - Native(WebSocketHandle), + Native { + inner: WebSocketHandle, + finalize_text: Utf8Bytes, + }, Split { mic: WebSocketHandle, spk: WebSocketHandle, + finalize_text: Utf8Bytes, }, } pub trait FinalizeHandle: Send { - fn 
finalize_with_text(&self, text: Utf8Bytes) -> impl std::future::Future + Send; + fn finalize(&self) -> impl std::future::Future + Send; fn expected_finalize_count(&self) -> usize; } -impl FinalizeHandle for WebSocketHandle { - async fn finalize_with_text(&self, text: Utf8Bytes) { - self.finalize_with_text(text).await +impl FinalizeHandle for SingleHandle { + async fn finalize(&self) { + self.inner + .finalize_with_text(self.finalize_text.clone()) + .await } fn expected_finalize_count(&self) -> usize { @@ -49,14 +63,20 @@ impl FinalizeHandle for WebSocketHandle { } impl FinalizeHandle for DualHandle { - async fn finalize_with_text(&self, text: Utf8Bytes) { + async fn finalize(&self) { match self { - DualHandle::Native(h) => h.finalize_with_text(text).await, - DualHandle::Split { mic, spk } => { - let text_clone = text.clone(); + DualHandle::Native { + inner, + finalize_text, + } => inner.finalize_with_text(finalize_text.clone()).await, + DualHandle::Split { + mic, + spk, + finalize_text, + } => { tokio::join!( - WebSocketHandle::finalize_with_text(mic, text), - WebSocketHandle::finalize_with_text(spk, text_clone) + mic.finalize_with_text(finalize_text.clone()), + spk.finalize_with_text(finalize_text.clone()) ); } } @@ -64,7 +84,7 @@ impl FinalizeHandle for DualHandle { fn expected_finalize_count(&self) -> usize { match self { - DualHandle::Native(_) => 1, + DualHandle::Native { .. } => 1, DualHandle::Split { .. 
} => 2, } } @@ -98,7 +118,7 @@ pub struct ListenClientIO; impl WebSocketIO for ListenClientIO { type Data = ListenClientInput; type Input = ListenClientInput; - type Output = StreamResponse; + type Output = String; fn to_input(data: Self::Data) -> Self::Input { data @@ -115,7 +135,7 @@ impl WebSocketIO for ListenClientIO { fn from_message(msg: Message) -> Option { match msg { - Message::Text(text) => serde_json::from_str::(&text).ok(), + Message::Text(text) => Some(text.to_string()), _ => None, } } @@ -126,7 +146,7 @@ pub struct ListenClientDualIO; impl WebSocketIO for ListenClientDualIO { type Data = ListenClientDualInput; type Input = ListenClientInput; - type Output = StreamResponse; + type Output = String; fn to_input(data: Self::Data) -> Self::Input { match data { @@ -149,29 +169,51 @@ impl WebSocketIO for ListenClientDualIO { fn from_message(msg: Message) -> Option { match msg { - Message::Text(text) => serde_json::from_str::(&text).ok(), + Message::Text(text) => Some(text.to_string()), _ => None, } } } -impl ListenClient { - pub fn builder() -> ListenClientBuilder { +impl ListenClient { + pub fn builder() -> ListenClientBuilder { ListenClientBuilder::default() } +} +impl ListenClient { pub async fn from_realtime_audio( self, audio_stream: impl Stream + Send + Unpin + 'static, ) -> Result< ( impl Stream>, - hypr_ws::client::WebSocketHandle, + SingleHandle, ), hypr_ws::Error, > { - let ws = websocket_client_with_keep_alive(&self.request); - ws.from_audio::(audio_stream).await + let finalize_text = extract_finalize_text(&self.adapter); + let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); + let (raw_stream, inner) = ws + .from_audio::(self.initial_message, audio_stream) + .await?; + + let adapter = self.adapter; + let mapped_stream = raw_stream.filter_map(move |result| { + let adapter = adapter.clone(); + async move { + match result { + Ok(raw) => adapter.parse_response(&raw).map(Ok), + Err(e) => Some(Err(e)), + } + } + }); + + let handle = 
SingleHandle { + inner, + finalize_text, + }; + Ok((mapped_stream, handle)) } } @@ -193,32 +235,80 @@ impl ListenClientDual { self, stream: impl Stream + Send + Unpin + 'static, ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { - let ws = websocket_client_with_keep_alive(&self.request); - let (output_stream, handle) = ws.from_audio::(stream).await?; - Ok((Box::pin(output_stream), DualHandle::Native(handle))) + let finalize_text = extract_finalize_text(&self.adapter); + let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); + let (raw_stream, inner) = ws + .from_audio::(self.initial_message, stream) + .await?; + + let adapter = self.adapter; + let mapped_stream = raw_stream.filter_map(move |result| { + let adapter = adapter.clone(); + async move { + match result { + Ok(raw) => adapter.parse_response(&raw).map(Ok), + Err(e) => Some(Err(e)), + } + } + }); + + let handle = DualHandle::Native { + inner, + finalize_text, + }; + Ok((Box::pin(mapped_stream), handle)) } async fn from_realtime_audio_split( self, stream: impl Stream + Send + Unpin + 'static, ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { + let finalize_text = extract_finalize_text(&self.adapter); let (mic_tx, mic_rx) = tokio::sync::mpsc::channel::(32); let (spk_tx, spk_rx) = tokio::sync::mpsc::channel::(32); - let mic_ws = websocket_client_with_keep_alive(&self.request); - let spk_ws = websocket_client_with_keep_alive(&self.request); + let mic_ws = websocket_client_with_keep_alive(&self.request, &self.adapter); + let spk_ws = websocket_client_with_keep_alive(&self.request, &self.adapter); let mic_outbound = tokio_stream::wrappers::ReceiverStream::new(mic_rx); let spk_outbound = tokio_stream::wrappers::ReceiverStream::new(spk_rx); - let mic_connect = mic_ws.from_audio::(mic_outbound); - let spk_connect = spk_ws.from_audio::(spk_outbound); + let mic_connect = + mic_ws.from_audio::(self.initial_message.clone(), mic_outbound); + let spk_connect = 
spk_ws.from_audio::(self.initial_message, spk_outbound); - let ((mic_stream, mic_handle), (spk_stream, spk_handle)) = + let ((mic_raw, mic_handle), (spk_raw, spk_handle)) = tokio::try_join!(mic_connect, spk_connect)?; tokio::spawn(forward_dual_to_single(stream, mic_tx, spk_tx)); + let adapter = self.adapter.clone(); + let mic_stream = mic_raw.filter_map({ + let adapter = adapter.clone(); + move |result| { + let adapter = adapter.clone(); + async move { + match result { + Ok(raw) => adapter.parse_response(&raw).map(Ok), + Err(e) => Some(Err(e)), + } + } + } + }); + + let spk_stream = spk_raw.filter_map({ + let adapter = adapter.clone(); + move |result| { + let adapter = adapter.clone(); + async move { + match result { + Ok(raw) => adapter.parse_response(&raw).map(Ok), + Err(e) => Some(Err(e)), + } + } + } + }); + let merged_stream = merge_streams_with_channel_remap(mic_stream, spk_stream); Ok(( @@ -226,6 +316,7 @@ impl ListenClientDual { DualHandle::Split { mic: mic_handle, spk: spk_handle, + finalize_text, }, )) } @@ -275,15 +366,22 @@ where futures_util::stream::select(mic_mapped, spk_mapped) } -fn websocket_client_with_keep_alive(request: &ClientRequestBuilder) -> WebSocketClient { - WebSocketClient::new(request.clone()) - .with_keep_alive_message(Duration::from_secs(5), keep_alive_message()) +fn websocket_client_with_keep_alive( + request: &ClientRequestBuilder, + adapter: &A, +) -> WebSocketClient { + let mut client = WebSocketClient::new(request.clone()); + + if let Some(keep_alive) = adapter.keep_alive_message() { + client = client.with_keep_alive_message(Duration::from_secs(5), keep_alive); + } + + client } -fn keep_alive_message() -> Message { - Message::Text( - serde_json::to_string(&ControlMessage::KeepAlive) - .unwrap() - .into(), - ) +fn extract_finalize_text(adapter: &A) -> Utf8Bytes { + match adapter.finalize_message() { + Message::Text(text) => text, + _ => r#"{"type":"Finalize"}"#.into(), + } } diff --git a/plugins/listener/src/actors/listener.rs 
b/plugins/listener/src/actors/listener.rs index 28d41b9c82..7c50897ecb 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -4,7 +4,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use futures_util::StreamExt; use tokio::time::error::Elapsed; -use owhisper_client::{ArgmaxAdapter, DeepgramAdapter, FinalizeHandle, SttAdapter}; +use owhisper_client::{ArgmaxAdapter, DeepgramAdapter, FinalizeHandle, SonioxAdapter, SttAdapter}; use owhisper_interface::stream::{Extra, StreamResponse}; use owhisper_interface::{ControlMessage, MixedMessage}; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; @@ -194,6 +194,17 @@ fn is_local_stt_base_url(base_url: &str) -> bool { } } +fn is_soniox_base_url(base_url: &str) -> bool { + if let Ok(parsed) = url::Url::parse(base_url) { + parsed + .host_str() + .map(|h| h.contains("soniox.com")) + .unwrap_or(false) + } else { + base_url.contains("soniox.com") + } +} + async fn spawn_rx_task( args: ListenerArgs, myself: ActorRef, @@ -207,11 +218,19 @@ async fn spawn_rx_task( > { match args.mode { crate::actors::ChannelMode::MicOnly | crate::actors::ChannelMode::SpeakerOnly => { - spawn_rx_task_single(args, myself).await + if is_local_stt_base_url(&args.base_url) { + spawn_rx_task_single_with_adapter::(args, myself).await + } else if is_soniox_base_url(&args.base_url) { + spawn_rx_task_single_with_adapter::(args, myself).await + } else { + spawn_rx_task_single_with_adapter::(args, myself).await + } } crate::actors::ChannelMode::MicAndSpeaker => { if is_local_stt_base_url(&args.base_url) { spawn_rx_task_dual_with_adapter::(args, myself).await + } else if is_soniox_base_url(&args.base_url) { + spawn_rx_task_dual_with_adapter::(args, myself).await } else { spawn_rx_task_dual_with_adapter::(args, myself).await } @@ -246,7 +265,7 @@ fn build_extra(args: &ListenerArgs) -> (f64, Extra) { (session_offset_secs, extra) } -async fn spawn_rx_task_single( +async fn 
spawn_rx_task_single_with_adapter( args: ListenerArgs, myself: ActorRef, ) -> Result< @@ -263,6 +282,7 @@ async fn spawn_rx_task_single( let (tx, rx) = tokio::sync::mpsc::channel::>(32); let client = owhisper_client::ListenClient::builder() + .adapter::() .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) @@ -378,7 +398,7 @@ async fn process_stream( loop { tokio::select! { _ = &mut shutdown_rx => { - handle.finalize_with_text(serde_json::json!({"type": "Finalize"}).to_string().into()).await; + handle.finalize().await; let finalize_timeout = tokio::time::sleep(Duration::from_secs(5)); tokio::pin!(finalize_timeout); diff --git a/plugins/listener2/src/ext.rs b/plugins/listener2/src/ext.rs index 70621ebff2..c896afa54f 100644 --- a/plugins/listener2/src/ext.rs +++ b/plugins/listener2/src/ext.rs @@ -10,6 +10,7 @@ use crate::BatchEvent; #[serde(rename_all = "lowercase")] pub enum BatchProvider { Deepgram, + Soniox, Am, } @@ -123,7 +124,7 @@ impl> Listener2PluginExt for T { .emit(&app) .map_err(|_| crate::Error::BatchStartFailed("failed to emit event".to_string()))?; - let client = owhisper_client::BatchClient::builder() + let client = owhisper_client::ListenClient::builder() .api_base(params.base_url.clone()) .api_key(params.api_key.clone()) .params(listen_params) @@ -141,6 +142,36 @@ impl> Listener2PluginExt for T { .emit(&app) .map_err(|_| crate::Error::BatchStartFailed("failed to emit event".to_string()))?; + Ok(()) + } + BatchProvider::Soniox => { + tracing::debug!("using Soniox batch client"); + + BatchEvent::BatchStarted { + session_id: params.session_id.clone(), + } + .emit(&app) + .map_err(|_| crate::Error::BatchStartFailed("failed to emit event".to_string()))?; + + let client = owhisper_client::ListenClient::builder() + .adapter::() + .api_base(params.base_url.clone()) + .api_key(params.api_key.clone()) + .params(listen_params) + .build_batch(); + + tracing::debug!("transcribing file: {}", params.file_path); + let 
response = client.transcribe_file(&params.file_path).await?; + + tracing::info!("Soniox batch transcription completed, emitting response"); + + BatchEvent::BatchResponse { + session_id: params.session_id.clone(), + response, + } + .emit(&app) + .map_err(|_| crate::Error::BatchStartFailed("failed to emit event".to_string()))?; + Ok(()) } }