diff --git a/owhisper/owhisper-client/src/adapter/deepgram.rs b/owhisper/owhisper-client/src/adapter/deepgram.rs new file mode 100644 index 0000000000..06e4d1f908 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/deepgram.rs @@ -0,0 +1,195 @@ +use std::time::Duration; + +use hypr_ws::client::{ClientRequestBuilder, Message}; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::{ControlMessage, ListenParams}; +use url::form_urlencoded::Serializer; +use url::UrlQuery; + +use super::SttAdapter; + +/// Deepgram STT adapter. +/// +/// This adapter implements the Deepgram-like API format, which is also used by +/// owhisper-server and other compatible services. +#[derive(Clone, Default)] +pub struct DeepgramAdapter; + +impl DeepgramAdapter { + pub fn new() -> Self { + Self + } + + fn listen_endpoint_url(&self, api_base: &str) -> url::Url { + let mut url: url::Url = api_base.parse().expect("invalid api_base"); + + let mut path = url.path().to_string(); + if !path.ends_with('/') { + path.push('/'); + } + path.push_str("listen"); + url.set_path(&path); + + url + } + + fn apply_ws_scheme(&self, url: &mut url::Url) { + if let Some(host) = url.host_str() { + if host.contains("127.0.0.1") || host.contains("localhost") || host.contains("0.0.0.0") + { + let _ = url.set_scheme("ws"); + } else { + let _ = url.set_scheme("wss"); + } + } + } +} + +impl SttAdapter for DeepgramAdapter { + fn build_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url { + let mut url = self.listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let channel_string = channels.to_string(); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("channels", &channel_string); + query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("interim_results", "true"); + query_pairs.append_pair("mip_opt_out", "true"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "true"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("vad_events", "false"); + query_pairs.append_pair("numerals", "true"); + + let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); + query_pairs.append_pair("redemption_time_ms", &redemption_time); + + append_keyword_query(&mut query_pairs, params); + } + + self.apply_ws_scheme(&mut url); + url + } + + fn build_batch_url(&self, api_base: &str, params: &ListenParams) -> url::Url { + let mut url = self.listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "false"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("utterances", "true"); + query_pairs.append_pair("numerals", "true"); + query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("dictation", "false"); + query_pairs.append_pair("paragraphs", "false"); + query_pairs.append_pair("profanity_filter", "false"); + query_pairs.append_pair("measurements", "false"); + query_pairs.append_pair("topics", "false"); + query_pairs.append_pair("sentiment", "false"); + query_pairs.append_pair("intents", "false"); + query_pairs.append_pair("detect_entities", "false"); + query_pairs.append_pair("mip_opt_out", "true"); + + append_keyword_query(&mut query_pairs, params); + } + + url + } + + fn build_request(&self, url: url::Url, api_key: Option<&str>) -> ClientRequestBuilder { + let uri = url.to_string().parse().unwrap(); + + match api_key { + Some(key) => ClientRequestBuilder::new(uri) + .with_header("Authorization", format!("Token {}", key)), + None => ClientRequestBuilder::new(uri), + } + } + + fn encode_audio(&self, audio: bytes::Bytes) -> Message { + Message::Binary(audio) + } + + fn encode_control(&self, control: &ControlMessage) -> Message { + Message::Text(serde_json::to_string(control).unwrap().into()) + } + + fn decode_response(&self, msg: Message) -> Option { + match msg { + Message::Text(text) => serde_json::from_str::(&text).ok(), + _ => None, + } + } + + fn keep_alive_config(&self) -> Option<(Duration, Message)> { + let message = Message::Text( + serde_json::to_string(&ControlMessage::KeepAlive) + .unwrap() + .into(), + ); + Some((Duration::from_secs(5), message)) + } +} + +fn append_language_query<'a>(query_pairs: &mut Serializer<'a, UrlQuery>, params: &ListenParams) { + match params.languages.len() { + 0 => { + query_pairs.append_pair("detect_language", "true"); + } + 1 => { + if let Some(language) = params.languages.first() { + let code = language.iso639().code(); + query_pairs.append_pair("language", code); + query_pairs.append_pair("languages", code); + } + } + _ => { + query_pairs.append_pair("language", "multi"); + for language in ¶ms.languages { + let code = language.iso639().code(); + query_pairs.append_pair("languages", code); + } + } + } +} + +fn append_keyword_query<'a>(query_pairs: &mut Serializer<'a, UrlQuery>, params: &ListenParams) { + if params.keywords.is_empty() { + return; + } + + let use_keyterms = params + .model + .as_ref() + .map(|model| model.contains("nova-3")) + .unwrap_or(false); + + let param_name = if use_keyterms { "keyterm" } else { "keywords" }; + + for keyword in ¶ms.keywords { + query_pairs.append_pair(param_name, keyword); + } +} diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs new file mode 100644 index 0000000000..e7f5a71028 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -0,0 +1,47 @@ +mod deepgram; + +pub use deepgram::DeepgramAdapter; + +use std::time::Duration; + +use hypr_ws::client::{ClientRequestBuilder, Message}; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::{ControlMessage, ListenParams}; + +/// Trait for STT provider adapters. +/// +/// This trait encapsulates provider-specific logic for: +/// - Building WebSocket URLs with provider-specific query parameters +/// - Building WebSocket requests with authentication +/// - Encoding audio and control messages to WebSocket format +/// - Decoding provider responses to the common StreamResponse format +/// - Keep-alive configuration +pub trait SttAdapter: Clone + Send + Sync + 'static { + /// Build the WebSocket URL for this provider. + /// + /// # Arguments + /// * `api_base` - The base URL for the provider's API + /// * `params` - Listen parameters (model, languages, sample_rate, etc.) + /// * `channels` - Number of audio channels (1 for single, 2 for dual) + fn build_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url; + + /// Build the WebSocket URL for batch transcription. + fn build_batch_url(&self, api_base: &str, params: &ListenParams) -> url::Url; + + /// Build the WebSocket request with authentication headers. + fn build_request(&self, url: url::Url, api_key: Option<&str>) -> ClientRequestBuilder; + + /// Encode audio bytes to a WebSocket message. + fn encode_audio(&self, audio: bytes::Bytes) -> Message; + + /// Encode a control message to a WebSocket message. + fn encode_control(&self, control: &ControlMessage) -> Message; + + /// Decode a WebSocket message to a StreamResponse. + /// Returns None if the message cannot be decoded (e.g., ping/pong messages). + fn decode_response(&self, msg: Message) -> Option; + + /// Get the keep-alive configuration for this provider. + /// Returns None if keep-alive is not needed. + fn keep_alive_config(&self) -> Option<(Duration, Message)>; +} diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index cc80f0cb41..1523bef4fd 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -1,162 +1,122 @@ +pub mod adapter; mod batch; mod error; mod live; -use url::form_urlencoded::Serializer; -use url::UrlQuery; - +pub use adapter::{DeepgramAdapter, SttAdapter}; pub use batch::BatchClient; pub use error::Error; pub use hypr_ws; pub use live::{ListenClient, ListenClientDual}; -#[derive(Default)] -pub struct ListenClientBuilder { +/// Builder for creating STT clients. +/// +/// The builder is generic over the adapter type, which determines how to +/// communicate with the STT provider. By default, it uses `DeepgramAdapter`. +/// +/// # Example +/// +/// ```ignore +/// // Using default Deepgram adapter +/// let client = ListenClient::builder() +/// .api_base("https://api.deepgram.com/v1") +/// .api_key("your-api-key") +/// .params(params) +/// .build_single(); +/// +/// // Using explicit adapter +/// let client = ListenClientBuilder::with_adapter(DeepgramAdapter::new()) +/// .api_base("https://api.deepgram.com/v1") +/// .api_key("your-api-key") +/// .params(params) +/// .build_single(); +/// ``` +pub struct ListenClientBuilder { + adapter: A, api_base: Option, api_key: Option, params: Option, } -impl ListenClientBuilder { +impl Default for ListenClientBuilder { + fn default() -> Self { + Self { + adapter: DeepgramAdapter::default(), + api_base: None, + api_key: None, + params: None, + } + } +} + +impl ListenClientBuilder { + /// Create a new builder with the default Deepgram adapter. + pub fn new() -> Self { + Self::default() + } +} + +impl ListenClientBuilder { + /// Create a new builder with a specific adapter. + pub fn with_adapter(adapter: A) -> Self { + Self { + adapter, + api_base: None, + api_key: None, + params: None, + } + } + + /// Set the API base URL. pub fn api_base(mut self, api_base: impl Into) -> Self { self.api_base = Some(api_base.into()); self } + /// Set the API key for authentication. pub fn api_key(mut self, api_key: impl Into) -> Self { self.api_key = Some(api_key.into()); self } + /// Set the listen parameters. pub fn params(mut self, params: owhisper_interface::ListenParams) -> Self { self.params = Some(params); self } - fn listen_endpoint_url(&self) -> url::Url { - let mut url: url::Url = self - .api_base - .as_ref() - .expect("api_base is required") - .parse() - .expect("invalid api_base"); - - let mut path = url.path().to_string(); - if !path.ends_with('/') { - path.push('/'); - } - path.push_str("listen"); - url.set_path(&path); - - url - } - - pub(crate) fn build_batch_url(&self) -> url::Url { - let params = self.params.clone().unwrap_or_default(); - let mut url = self.listen_endpoint_url(); - - { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, ¶ms); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("sample_rate", &sample_rate); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "false"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("utterances", "true"); - query_pairs.append_pair("numerals", "true"); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("dictation", "false"); - query_pairs.append_pair("paragraphs", "false"); - query_pairs.append_pair("profanity_filter", "false"); - query_pairs.append_pair("measurements", "false"); - query_pairs.append_pair("topics", "false"); - query_pairs.append_pair("sentiment", "false"); - query_pairs.append_pair("intents", "false"); - query_pairs.append_pair("detect_entities", "false"); - query_pairs.append_pair("mip_opt_out", "true"); - - append_keyword_query(&mut query_pairs, ¶ms); - } - - url + /// Get a reference to the adapter. + pub fn adapter(&self) -> &A { + &self.adapter } pub(crate) fn build_url(&self, channels: u8) -> url::Url { - let mut params = self.params.clone().unwrap_or_default(); - params.channels = channels; - - let mut url = self.listen_endpoint_url(); - - { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, ¶ms); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let channel_string = channels.to_string(); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("channels", &channel_string); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("interim_results", "true"); - query_pairs.append_pair("mip_opt_out", "true"); - query_pairs.append_pair("sample_rate", &sample_rate); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "true"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("vad_events", "false"); - query_pairs.append_pair("numerals", "true"); - - let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); - query_pairs.append_pair("redemption_time_ms", &redemption_time); - - append_keyword_query(&mut query_pairs, ¶ms); - } - - url + let api_base = self.api_base.as_ref().expect("api_base is required"); + let params = self.params.clone().unwrap_or_default(); + self.adapter.build_url(api_base, ¶ms, channels) } - pub(crate) fn build_uri(&self, channels: u8) -> String { - let mut url = self.build_url(channels); - - if let Some(host) = url.host_str() { - if host.contains("127.0.0.1") || host.contains("localhost") || host.contains("0.0.0.0") - { - let _ = url.set_scheme("ws"); - } else { - let _ = url.set_scheme("wss"); - } - } - - url.to_string() + pub(crate) fn build_batch_url(&self) -> url::Url { + let api_base = self.api_base.as_ref().expect("api_base is required"); + let params = self.params.clone().unwrap_or_default(); + self.adapter.build_batch_url(api_base, ¶ms) } pub(crate) fn build_request(&self, channels: u8) -> hypr_ws::client::ClientRequestBuilder { - let uri = self.build_uri(channels).parse().unwrap(); - - match &self.api_key { - Some(key) => hypr_ws::client::ClientRequestBuilder::new(uri) - .with_header("Authorization", format!("Token {}", key)), - None => hypr_ws::client::ClientRequestBuilder::new(uri), - } + let url = self.build_url(channels); + self.adapter.build_request(url, self.api_key.as_deref()) } - pub fn build_with_channels(self, channels: u8) -> ListenClient { + /// Build a client with the specified number of channels. + pub fn build_with_channels(self, channels: u8) -> ListenClient { let request = self.build_request(channels); - ListenClient { request } + ListenClient { + adapter: self.adapter, + request, + } } + /// Build a batch client for pre-recorded audio transcription. pub fn build_batch(self) -> BatchClient { let url = self.build_batch_url(); @@ -167,59 +127,18 @@ impl ListenClientBuilder { } } - pub fn build_single(self) -> ListenClient { + /// Build a single-channel client. + pub fn build_single(self) -> ListenClient { self.build_with_channels(1) } - pub fn build_dual(self) -> ListenClientDual { + /// Build a dual-channel client (mic + speaker). + pub fn build_dual(self) -> ListenClientDual { let request = self.build_request(2); - ListenClientDual { request } - } -} - -pub(crate) fn append_language_query<'a>( - query_pairs: &mut Serializer<'a, UrlQuery>, - params: &owhisper_interface::ListenParams, -) { - match params.languages.len() { - 0 => { - query_pairs.append_pair("detect_language", "true"); + ListenClientDual { + adapter: self.adapter, + request, } - 1 => { - if let Some(language) = params.languages.first() { - let code = language.iso639().code(); - query_pairs.append_pair("language", code); - query_pairs.append_pair("languages", code); - } - } - _ => { - query_pairs.append_pair("language", "multi"); - for language in ¶ms.languages { - let code = language.iso639().code(); - query_pairs.append_pair("languages", code); - } - } - } -} - -pub(crate) fn append_keyword_query<'a>( - query_pairs: &mut Serializer<'a, UrlQuery>, - params: &owhisper_interface::ListenParams, -) { - if params.keywords.is_empty() { - return; - } - - let use_keyterms = params - .model - .as_ref() - .map(|model| model.contains("nova-3")) - .unwrap_or(false); - - let param_name = if use_keyterms { "keyterm" } else { "keywords" }; - - for keyword in ¶ms.keywords { - query_pairs.append_pair(param_name, keyword); } } diff --git a/owhisper/owhisper-client/src/live.rs b/owhisper/owhisper-client/src/live.rs index c0d8f6e571..525d6f4ac3 100644 --- a/owhisper/owhisper-client/src/live.rs +++ b/owhisper/owhisper-client/src/live.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::time::Duration; use futures_util::Stream; @@ -6,18 +7,29 @@ use hypr_ws::client::{ClientRequestBuilder, Message, WebSocketClient, WebSocketI use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use crate::ListenClientBuilder; +use crate::adapter::SttAdapter; +use crate::{DeepgramAdapter, ListenClientBuilder}; pub type ListenClientInput = MixedMessage; pub type ListenClientDualInput = MixedMessage<(bytes::Bytes, bytes::Bytes), ControlMessage>; +/// A single-channel STT client. +/// +/// This client is generic over the adapter type, which determines how to +/// communicate with the STT provider. #[derive(Clone)] -pub struct ListenClient { +pub struct ListenClient { + pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, } +/// A dual-channel STT client (mic + speaker). +/// +/// This client is generic over the adapter type, which determines how to +/// communicate with the STT provider. #[derive(Clone)] -pub struct ListenClientDual { +pub struct ListenClientDual { + pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, } @@ -44,7 +56,15 @@ fn interleave_audio(mic: &[u8], speaker: &[u8]) -> Vec { interleaved } -impl WebSocketIO for ListenClient { +/// WebSocket IO wrapper for single-channel client. +/// +/// This struct is used internally to implement the WebSocketIO trait +/// for the ListenClient, handling message encoding/decoding. +pub struct ListenClientIO { + _marker: PhantomData, +} + +impl WebSocketIO for ListenClientIO { type Data = ListenClientInput; type Input = ListenClientInput; type Output = StreamResponse; @@ -70,7 +90,16 @@ impl WebSocketIO for ListenClient { } } -impl WebSocketIO for ListenClientDual { +/// WebSocket IO wrapper for dual-channel client. +/// +/// This struct is used internally to implement the WebSocketIO trait +/// for the ListenClientDual, handling message encoding/decoding and +/// audio interleaving. +pub struct ListenClientDualIO { + _marker: PhantomData, +} + +impl WebSocketIO for ListenClientDualIO { type Data = ListenClientDualInput; type Input = ListenClientInput; type Output = StreamResponse; @@ -102,11 +131,22 @@ impl WebSocketIO for ListenClientDual { } } -impl ListenClient { - pub fn builder() -> ListenClientBuilder { +impl ListenClient { + /// Create a new builder with the default Deepgram adapter. + pub fn builder() -> ListenClientBuilder { ListenClientBuilder::default() } +} +impl ListenClient { + /// Get a reference to the adapter. + pub fn adapter(&self) -> &A { + &self.adapter + } + + /// Connect to the STT service and start streaming audio. + /// + /// Returns a stream of transcription responses and a handle to control the connection. pub async fn from_realtime_audio( self, audio_stream: impl Stream + Send + Unpin + 'static, @@ -117,12 +157,20 @@ impl ListenClient { ), hypr_ws::Error, > { - let ws = websocket_client_with_keep_alive(&self.request); - ws.from_audio::(audio_stream).await + let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); + ws.from_audio::>(audio_stream).await } } -impl ListenClientDual { +impl ListenClientDual { + /// Get a reference to the adapter. + pub fn adapter(&self) -> &A { + &self.adapter + } + + /// Connect to the STT service and start streaming dual-channel audio. + /// + /// Returns a stream of transcription responses and a handle to control the connection. pub async fn from_realtime_audio( self, stream: impl Stream + Send + Unpin + 'static, @@ -133,14 +181,22 @@ impl ListenClientDual { ), hypr_ws::Error, > { - let ws = websocket_client_with_keep_alive(&self.request); - ws.from_audio::(stream).await + let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); + ws.from_audio::>(stream).await } } -fn websocket_client_with_keep_alive(request: &ClientRequestBuilder) -> WebSocketClient { - WebSocketClient::new(request.clone()) - .with_keep_alive_message(Duration::from_secs(5), keep_alive_message()) +fn websocket_client_with_keep_alive( + request: &ClientRequestBuilder, + adapter: &A, +) -> WebSocketClient { + let ws = WebSocketClient::new(request.clone()); + + if let Some((interval, message)) = adapter.keep_alive_config() { + ws.with_keep_alive_message(interval, message) + } else { + ws.with_keep_alive_message(Duration::from_secs(5), keep_alive_message()) + } } fn keep_alive_message() -> Message {