diff --git a/owhisper/owhisper-client/src/adapter/argmax/batch.rs b/owhisper/owhisper-client/src/adapter/argmax/batch.rs new file mode 100644 index 0000000000..dda36923f6 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/argmax/batch.rs @@ -0,0 +1,20 @@ +use std::path::Path; + +use owhisper_interface::ListenParams; + +use super::ArgmaxAdapter; +use crate::adapter::{BatchFuture, BatchSttAdapter}; + +impl BatchSttAdapter for ArgmaxAdapter { + fn transcribe_file<'a, P: AsRef<Path> + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a> { + self.inner + .transcribe_file(client, api_base, api_key, params, file_path) + } +} diff --git a/owhisper/owhisper-client/src/adapter/argmax.rs b/owhisper/owhisper-client/src/adapter/argmax/live.rs similarity index 57% rename from owhisper/owhisper-client/src/adapter/argmax.rs rename to owhisper/owhisper-client/src/adapter/argmax/live.rs index 23ab712f4e..4afa65a201 100644 --- a/owhisper/owhisper-client/src/adapter/argmax.rs +++ b/owhisper/owhisper-client/src/adapter/argmax/live.rs @@ -1,49 +1,11 @@ -use std::path::Path; - use hypr_ws::client::Message; use owhisper_interface::stream::StreamResponse; use owhisper_interface::ListenParams; -use super::{BatchFuture, DeepgramAdapter, SttAdapter}; - -const PARAKEET_V3_LANGS: &[&str] = &[ - "bg", "cs", "da", "de", "el", "en", "es", "et", "fi", "fr", "hr", "hu", "it", "lt", "lv", "mt", - "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "uk", -]; - -#[derive(Clone, Default)] -pub struct ArgmaxAdapter { - inner: DeepgramAdapter, -} - -impl ArgmaxAdapter { - fn adapt_params(params: &ListenParams) -> ListenParams { - let mut adapted = params.clone(); - let model = params.model.as_deref().unwrap_or(""); - - let lang = if model.contains("parakeet") && model.contains("v2") { - hypr_language::ISO639::En.into() - } else if model.contains("parakeet") && model.contains("v3") { - params - 
.languages - .iter() - .find(|lang| PARAKEET_V3_LANGS.contains(&lang.iso639().code())) - .cloned() - .unwrap_or_else(|| hypr_language::ISO639::En.into()) - } else { - params - .languages - .first() - .cloned() - .unwrap_or_else(|| hypr_language::ISO639::En.into()) - }; +use super::ArgmaxAdapter; +use crate::adapter::RealtimeSttAdapter; - adapted.languages = vec![lang]; - adapted - } -} - -impl SttAdapter for ArgmaxAdapter { +impl RealtimeSttAdapter for ArgmaxAdapter { fn supports_native_multichannel(&self) -> bool { false } @@ -68,18 +30,6 @@ impl SttAdapter for ArgmaxAdapter { fn parse_response(&self, raw: &str) -> Option<StreamResponse> { self.inner.parse_response(raw) } - - fn transcribe_file<'a, P: AsRef<Path> + Send + 'a>( - &'a self, - client: &'a reqwest::Client, - api_base: &'a str, - api_key: &'a str, - params: &'a ListenParams, - file_path: P, - ) -> BatchFuture<'a> { - self.inner - .transcribe_file(client, api_base, api_key, params, file_path) - } } #[cfg(test)] diff --git a/owhisper/owhisper-client/src/adapter/argmax/mod.rs b/owhisper/owhisper-client/src/adapter/argmax/mod.rs new file mode 100644 index 0000000000..a54d2368e6 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/argmax/mod.rs @@ -0,0 +1,43 @@ +mod batch; +mod live; + +use owhisper_interface::ListenParams; + +use super::DeepgramAdapter; + +const PARAKEET_V3_LANGS: &[&str] = &[ + "bg", "cs", "da", "de", "el", "en", "es", "et", "fi", "fr", "hr", "hu", "it", "lt", "lv", "mt", + "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "uk", +]; + +#[derive(Clone, Default)] +pub struct ArgmaxAdapter { + inner: DeepgramAdapter, +} + +impl ArgmaxAdapter { + pub(crate) fn adapt_params(params: &ListenParams) -> ListenParams { + let mut adapted = params.clone(); + let model = params.model.as_deref().unwrap_or(""); + + let lang = if model.contains("parakeet") && model.contains("v2") { + hypr_language::ISO639::En.into() + } else if model.contains("parakeet") && model.contains("v3") { + params + .languages + .iter() + 
.find(|lang| PARAKEET_V3_LANGS.contains(&lang.iso639().code())) + .cloned() + .unwrap_or_else(|| hypr_language::ISO639::En.into()) + } else { + params + .languages + .first() + .cloned() + .unwrap_or_else(|| hypr_language::ISO639::En.into()) + }; + + adapted.languages = vec![lang]; + adapted + } +} diff --git a/owhisper/owhisper-client/src/adapter/deepgram.rs b/owhisper/owhisper-client/src/adapter/deepgram.rs deleted file mode 100644 index 1cb4c7138b..0000000000 --- a/owhisper/owhisper-client/src/adapter/deepgram.rs +++ /dev/null @@ -1,363 +0,0 @@ -use std::path::{Path, PathBuf}; - -use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source}; -use hypr_ws::client::Message; -use owhisper_interface::batch::Response as BatchResponse; -use owhisper_interface::stream::StreamResponse; -use owhisper_interface::ListenParams; -use url::form_urlencoded::Serializer; -use url::UrlQuery; - -use super::{BatchFuture, SttAdapter}; -use crate::error::Error; - -const NOVA2_MULTI_LANGS: &[&str] = &["en", "es"]; -const NOVA3_MULTI_LANGS: &[&str] = &["en", "es", "fr", "de", "hi", "ru", "pt", "ja", "it", "nl"]; - -#[derive(Clone, Default)] -pub struct DeepgramAdapter; - -impl DeepgramAdapter { - fn listen_endpoint_url(api_base: &str) -> url::Url { - let mut url: url::Url = api_base.parse().expect("invalid_api_base"); - - let mut path = url.path().to_string(); - if !path.ends_with('/') { - path.push('/'); - } - path.push_str("listen"); - url.set_path(&path); - - url - } - - fn build_batch_url(api_base: &str, params: &ListenParams) -> url::Url { - let mut url = Self::listen_endpoint_url(api_base); - - { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, params); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("sample_rate", 
&sample_rate); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "false"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("utterances", "true"); - query_pairs.append_pair("numerals", "true"); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("dictation", "false"); - query_pairs.append_pair("paragraphs", "false"); - query_pairs.append_pair("profanity_filter", "false"); - query_pairs.append_pair("measurements", "false"); - query_pairs.append_pair("topics", "false"); - query_pairs.append_pair("sentiment", "false"); - query_pairs.append_pair("intents", "false"); - query_pairs.append_pair("detect_entities", "false"); - query_pairs.append_pair("mip_opt_out", "true"); - - append_keyword_query(&mut query_pairs, params); - } - - url - } - - // https://developers.deepgram.com/reference/speech-to-text/listen-pre-recorded - // https://github.com/deepgram/deepgram-rust-sdk/blob/main/src/listen/rest.rs - async fn do_transcribe_file( - client: &reqwest::Client, - api_base: &str, - api_key: &str, - params: &ListenParams, - file_path: PathBuf, - ) -> Result { - let (audio_data, sample_rate) = decode_audio_to_linear16(file_path).await?; - - let url = { - let mut url = Self::build_batch_url(api_base, params); - url.query_pairs_mut() - .append_pair("sample_rate", &sample_rate.to_string()); - url - }; - - let content_type = format!("audio/raw;encoding=linear16;rate={}", sample_rate); - - let response = client - .post(url) - .header("Authorization", format!("Token {}", api_key)) - .header("Accept", "application/json") - .header("Content-Type", content_type) - .body(audio_data) - .send() - .await?; - - let status = response.status(); - if status.is_success() { - Ok(response.json().await?) 
- } else { - Err(Error::UnexpectedStatus { - status, - body: response.text().await.unwrap_or_default(), - }) - } - } -} - -impl SttAdapter for DeepgramAdapter { - fn supports_native_multichannel(&self) -> bool { - true - } - - fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url { - let mut url = Self::listen_endpoint_url(api_base); - - { - let mut query_pairs = url.query_pairs_mut(); - - append_language_query(&mut query_pairs, params); - - let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let channel_string = channels.to_string(); - let sample_rate = params.sample_rate.to_string(); - - query_pairs.append_pair("model", model); - query_pairs.append_pair("channels", &channel_string); - query_pairs.append_pair("filler_words", "false"); - query_pairs.append_pair("interim_results", "true"); - query_pairs.append_pair("mip_opt_out", "true"); - query_pairs.append_pair("sample_rate", &sample_rate); - query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("diarize", "true"); - query_pairs.append_pair("multichannel", "true"); - query_pairs.append_pair("punctuate", "true"); - query_pairs.append_pair("smart_format", "true"); - query_pairs.append_pair("vad_events", "false"); - query_pairs.append_pair("numerals", "true"); - - let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); - query_pairs.append_pair("redemption_time_ms", &redemption_time); - - append_keyword_query(&mut query_pairs, params); - } - - if let Some(host) = url.host_str() { - if host.contains("127.0.0.1") || host.contains("localhost") || host.contains("0.0.0.0") - { - let _ = url.set_scheme("ws"); - } else { - let _ = url.set_scheme("wss"); - } - } - - url - } - - fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { - api_key.map(|key| ("Authorization", format!("Token {}", key))) - } - - fn keep_alive_message(&self) -> Option { - Some(Message::Text( - 
serde_json::to_string(&owhisper_interface::ControlMessage::KeepAlive) - .unwrap() - .into(), - )) - } - - fn finalize_message(&self) -> Message { - Message::Text( - serde_json::to_string(&owhisper_interface::ControlMessage::Finalize) - .unwrap() - .into(), - ) - } - - fn parse_response(&self, raw: &str) -> Option { - serde_json::from_str(raw).ok() - } - - fn transcribe_file<'a, P: AsRef + Send + 'a>( - &'a self, - client: &'a reqwest::Client, - api_base: &'a str, - api_key: &'a str, - params: &'a ListenParams, - file_path: P, - ) -> BatchFuture<'a> { - let path = file_path.as_ref().to_path_buf(); - Box::pin(Self::do_transcribe_file( - client, api_base, api_key, params, path, - )) - } -} - -fn can_use_multi(model: &str, languages: &[hypr_language::Language]) -> bool { - if languages.len() < 2 { - return false; - } - - let multi_langs: &[&str] = if model.contains("nova-3") { - NOVA3_MULTI_LANGS - } else if model.contains("nova-2") { - NOVA2_MULTI_LANGS - } else { - return false; - }; - - languages - .iter() - .all(|lang| multi_langs.contains(&lang.iso639().code())) -} - -fn append_keyword_query<'a>(query_pairs: &mut Serializer<'a, UrlQuery>, params: &ListenParams) { - if params.keywords.is_empty() { - return; - } - - let use_keyterms = params - .model - .as_ref() - .map(|model| model.contains("nova-3") || model.contains("parakeet")) - .unwrap_or(false); - - let param_name = if use_keyterms { "keyterm" } else { "keywords" }; - - for keyword in ¶ms.keywords { - query_pairs.append_pair(param_name, keyword); - } -} - -pub(crate) fn append_language_query<'a>( - query_pairs: &mut Serializer<'a, UrlQuery>, - params: &ListenParams, -) { - let model = params.model.as_deref().unwrap_or(""); - - match params.languages.len() { - 0 => { - query_pairs.append_pair("detect_language", "true"); - } - 1 => { - if let Some(language) = params.languages.first() { - let code = language.iso639().code(); - query_pairs.append_pair("language", code); - } - } - _ => { - if can_use_multi(model, 
¶ms.languages) { - query_pairs.append_pair("language", "multi"); - for language in ¶ms.languages { - let code = language.iso639().code(); - query_pairs.append_pair("languages", code); - } - } else { - query_pairs.append_pair("detect_language", "true"); - for language in ¶ms.languages { - let code = language.iso639().code(); - query_pairs.append_pair("languages", code); - } - } - } - } -} - -async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> { - tokio::task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> { - let decoder = - source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; - - let channels = decoder.channels().max(1); - let sample_rate = decoder.sample_rate(); - - let samples = resample_audio(decoder, sample_rate) - .map_err(|err| Error::AudioProcessing(err.to_string()))?; - - let samples = if channels == 1 { - samples - } else { - let channels_usize = channels as usize; - let mut mono = Vec::with_capacity(samples.len() / channels_usize); - for frame in samples.chunks(channels_usize) { - if frame.is_empty() { - continue; - } - let sum: f32 = frame.iter().copied().sum(); - mono.push(sum / frame.len() as f32); - } - mono - }; - - if samples.is_empty() { - return Err(Error::AudioProcessing( - "audio file contains no samples".to_string(), - )); - } - - let bytes = f32_to_i16_bytes(samples.into_iter()); - - Ok((bytes, sample_rate)) - }) - .await? 
-} - -#[cfg(test)] -mod tests { - use futures_util::StreamExt; - use hypr_audio_utils::AudioFormatExt; - - use crate::live::ListenClientInput; - use crate::ListenClient; - - #[tokio::test] - async fn test_client() { - let _ = tracing_subscriber::fmt::try_init(); - - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - .to_i16_le_chunks(16000, 512); - - let input = Box::pin(tokio_stream::StreamExt::throttle( - audio.map(|chunk| ListenClientInput::Audio(chunk)), - std::time::Duration::from_millis(20), - )); - - let client = ListenClient::builder() - .api_base("https://api.deepgram.com/v1") - .api_key("71557216ffdd13bff22702be5017e4852c052b7c") - .params(owhisper_interface::ListenParams { - model: Some("nova-3".to_string()), - languages: vec![ - hypr_language::ISO639::En.into(), - hypr_language::ISO639::Es.into(), - ], - ..Default::default() - }) - .build_single(); - - let (stream, _) = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); - - while let Some(result) = stream.next().await { - match result { - Ok(response) => match response { - owhisper_interface::stream::StreamResponse::TranscriptResponse { - channel, - .. 
- } => { - println!("{:?}", channel.alternatives.first().unwrap().transcript); - } - _ => {} - }, - _ => {} - } - } - } -} diff --git a/owhisper/owhisper-client/src/adapter/deepgram/batch.rs b/owhisper/owhisper-client/src/adapter/deepgram/batch.rs new file mode 100644 index 0000000000..c6846a8742 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/deepgram/batch.rs @@ -0,0 +1,141 @@ +use std::path::{Path, PathBuf}; + +use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source}; +use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::ListenParams; + +use super::{append_keyword_query, append_language_query, DeepgramAdapter}; +use crate::adapter::{BatchFuture, BatchSttAdapter}; +use crate::error::Error; + +impl BatchSttAdapter for DeepgramAdapter { + fn transcribe_file<'a, P: AsRef + Send + 'a>( + &'a self, + client: &'a reqwest::Client, + api_base: &'a str, + api_key: &'a str, + params: &'a ListenParams, + file_path: P, + ) -> BatchFuture<'a> { + let path = file_path.as_ref().to_path_buf(); + Box::pin(Self::do_transcribe_file( + client, api_base, api_key, params, path, + )) + } +} + +impl DeepgramAdapter { + fn build_batch_url(api_base: &str, params: &ListenParams) -> url::Url { + let mut url = Self::listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "false"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("utterances", "true"); + query_pairs.append_pair("numerals", "true"); + 
query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("dictation", "false"); + query_pairs.append_pair("paragraphs", "false"); + query_pairs.append_pair("profanity_filter", "false"); + query_pairs.append_pair("measurements", "false"); + query_pairs.append_pair("topics", "false"); + query_pairs.append_pair("sentiment", "false"); + query_pairs.append_pair("intents", "false"); + query_pairs.append_pair("detect_entities", "false"); + query_pairs.append_pair("mip_opt_out", "true"); + + append_keyword_query(&mut query_pairs, params); + } + + url + } + + async fn do_transcribe_file( + client: &reqwest::Client, + api_base: &str, + api_key: &str, + params: &ListenParams, + file_path: PathBuf, + ) -> Result { + let (audio_data, sample_rate) = decode_audio_to_linear16(file_path).await?; + + let url = { + let mut url = Self::build_batch_url(api_base, params); + url.query_pairs_mut() + .append_pair("sample_rate", &sample_rate.to_string()); + url + }; + + let content_type = format!("audio/raw;encoding=linear16;rate={}", sample_rate); + + let response = client + .post(url) + .header("Authorization", format!("Token {}", api_key)) + .header("Accept", "application/json") + .header("Content-Type", content_type) + .body(audio_data) + .send() + .await?; + + let status = response.status(); + if status.is_success() { + Ok(response.json().await?) 
+ } else { + Err(Error::UnexpectedStatus { + status, + body: response.text().await.unwrap_or_default(), + }) + } + } +} + +async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> { + tokio::task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> { + let decoder = + source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let channels = decoder.channels().max(1); + let sample_rate = decoder.sample_rate(); + + let samples = resample_audio(decoder, sample_rate) + .map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let samples = if channels == 1 { + samples + } else { + let channels_usize = channels as usize; + let mut mono = Vec::with_capacity(samples.len() / channels_usize); + for frame in samples.chunks(channels_usize) { + if frame.is_empty() { + continue; + } + let sum: f32 = frame.iter().copied().sum(); + mono.push(sum / frame.len() as f32); + } + mono + }; + + if samples.is_empty() { + return Err(Error::AudioProcessing( + "audio file contains no samples".to_string(), + )); + } + + let bytes = f32_to_i16_bytes(samples.into_iter()); + + Ok((bytes, sample_rate)) + }) + .await? 
+} diff --git a/owhisper/owhisper-client/src/adapter/deepgram/live.rs b/owhisper/owhisper-client/src/adapter/deepgram/live.rs new file mode 100644 index 0000000000..71715d5eb3 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/deepgram/live.rs @@ -0,0 +1,136 @@ +use hypr_ws::client::Message; +use owhisper_interface::stream::StreamResponse; +use owhisper_interface::ListenParams; + +use super::{append_keyword_query, append_language_query, DeepgramAdapter}; +use crate::adapter::RealtimeSttAdapter; + +impl RealtimeSttAdapter for DeepgramAdapter { + fn supports_native_multichannel(&self) -> bool { + true + } + + fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url { + let mut url = Self::listen_endpoint_url(api_base); + + { + let mut query_pairs = url.query_pairs_mut(); + + append_language_query(&mut query_pairs, params); + + let model = params.model.as_deref().unwrap_or("hypr-whisper"); + let channel_string = channels.to_string(); + let sample_rate = params.sample_rate.to_string(); + + query_pairs.append_pair("model", model); + query_pairs.append_pair("channels", &channel_string); + query_pairs.append_pair("filler_words", "false"); + query_pairs.append_pair("interim_results", "true"); + query_pairs.append_pair("mip_opt_out", "true"); + query_pairs.append_pair("sample_rate", &sample_rate); + query_pairs.append_pair("encoding", "linear16"); + query_pairs.append_pair("diarize", "true"); + query_pairs.append_pair("multichannel", "true"); + query_pairs.append_pair("punctuate", "true"); + query_pairs.append_pair("smart_format", "true"); + query_pairs.append_pair("vad_events", "false"); + query_pairs.append_pair("numerals", "true"); + + let redemption_time = params.redemption_time_ms.unwrap_or(400).to_string(); + query_pairs.append_pair("redemption_time_ms", &redemption_time); + + append_keyword_query(&mut query_pairs, params); + } + + if let Some(host) = url.host_str() { + if host.contains("127.0.0.1") || host.contains("localhost") || 
host.contains("0.0.0.0") + { + let _ = url.set_scheme("ws"); + } else { + let _ = url.set_scheme("wss"); + } + } + + url + } + + fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { + api_key.map(|key| ("Authorization", format!("Token {}", key))) + } + + fn keep_alive_message(&self) -> Option { + Some(Message::Text( + serde_json::to_string(&owhisper_interface::ControlMessage::KeepAlive) + .unwrap() + .into(), + )) + } + + fn finalize_message(&self) -> Message { + Message::Text( + serde_json::to_string(&owhisper_interface::ControlMessage::Finalize) + .unwrap() + .into(), + ) + } + + fn parse_response(&self, raw: &str) -> Option { + serde_json::from_str(raw).ok() + } +} + +#[cfg(test)] +mod tests { + use futures_util::StreamExt; + use hypr_audio_utils::AudioFormatExt; + + use crate::live::ListenClientInput; + use crate::ListenClient; + + #[tokio::test] + async fn test_client() { + let _ = tracing_subscriber::fmt::try_init(); + + let audio = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512); + + let input = Box::pin(tokio_stream::StreamExt::throttle( + audio.map(|chunk| ListenClientInput::Audio(chunk)), + std::time::Duration::from_millis(20), + )); + + let client = ListenClient::builder() + .api_base("https://api.deepgram.com/v1") + .api_key(std::env::var("DEEPGRAM_API_KEY").expect("DEEPGRAM_API_KEY not set")) + .params(owhisper_interface::ListenParams { + model: Some("nova-3".to_string()), + languages: vec![ + hypr_language::ISO639::En.into(), + hypr_language::ISO639::Es.into(), + ], + ..Default::default() + }) + .build_single(); + + let (stream, _) = client.from_realtime_audio(input).await.unwrap(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + match result { + Ok(response) => match response { + owhisper_interface::stream::StreamResponse::TranscriptResponse { + channel, + .. 
+ } => { + println!("{:?}", channel.alternatives.first().unwrap().transcript); + } + _ => {} + }, + _ => {} + } + } + } +} diff --git a/owhisper/owhisper-client/src/adapter/deepgram/mod.rs b/owhisper/owhisper-client/src/adapter/deepgram/mod.rs new file mode 100644 index 0000000000..9b2af93aec --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/deepgram/mod.rs @@ -0,0 +1,100 @@ +mod batch; +mod live; + +use owhisper_interface::ListenParams; +use url::form_urlencoded::Serializer; +use url::UrlQuery; + +const NOVA2_MULTI_LANGS: &[&str] = &["en", "es"]; +const NOVA3_MULTI_LANGS: &[&str] = &["en", "es", "fr", "de", "hi", "ru", "pt", "ja", "it", "nl"]; + +#[derive(Clone, Default)] +pub struct DeepgramAdapter; + +impl DeepgramAdapter { + pub(crate) fn listen_endpoint_url(api_base: &str) -> url::Url { + let mut url: url::Url = api_base.parse().expect("invalid_api_base"); + + let mut path = url.path().to_string(); + if !path.ends_with('/') { + path.push('/'); + } + path.push_str("listen"); + url.set_path(&path); + + url + } +} + +fn can_use_multi(model: &str, languages: &[hypr_language::Language]) -> bool { + if languages.len() < 2 { + return false; + } + + let multi_langs: &[&str] = if model.contains("nova-3") { + NOVA3_MULTI_LANGS + } else if model.contains("nova-2") { + NOVA2_MULTI_LANGS + } else { + return false; + }; + + languages + .iter() + .all(|lang| multi_langs.contains(&lang.iso639().code())) +} + +pub(crate) fn append_keyword_query<'a>( + query_pairs: &mut Serializer<'a, UrlQuery>, + params: &ListenParams, +) { + if params.keywords.is_empty() { + return; + } + + let use_keyterms = params + .model + .as_ref() + .map(|model| model.contains("nova-3") || model.contains("parakeet")) + .unwrap_or(false); + + let param_name = if use_keyterms { "keyterm" } else { "keywords" }; + + for keyword in ¶ms.keywords { + query_pairs.append_pair(param_name, keyword); + } +} + +pub(crate) fn append_language_query<'a>( + query_pairs: &mut Serializer<'a, UrlQuery>, + params: 
&ListenParams, +) { + let model = params.model.as_deref().unwrap_or(""); + + match params.languages.len() { + 0 => { + query_pairs.append_pair("detect_language", "true"); + } + 1 => { + if let Some(language) = params.languages.first() { + let code = language.iso639().code(); + query_pairs.append_pair("language", code); + } + } + _ => { + if can_use_multi(model, ¶ms.languages) { + query_pairs.append_pair("language", "multi"); + for language in ¶ms.languages { + let code = language.iso639().code(); + query_pairs.append_pair("languages", code); + } + } else { + query_pairs.append_pair("detect_language", "true"); + for language in ¶ms.languages { + let code = language.iso639().code(); + query_pairs.append_pair("languages", code); + } + } + } + } +} diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs index 339e513829..56e5a9e1cd 100644 --- a/owhisper/owhisper-client/src/adapter/mod.rs +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -20,7 +20,7 @@ use crate::error::Error; pub type BatchFuture<'a> = Pin> + Send + 'a>>; -pub trait SttAdapter: Clone + Default + Send + Sync + 'static { +pub trait RealtimeSttAdapter: Clone + Default + Send + Sync + 'static { fn supports_native_multichannel(&self) -> bool; fn build_ws_url(&self, api_base: &str, params: &ListenParams, channels: u8) -> url::Url; @@ -41,7 +41,9 @@ pub trait SttAdapter: Clone + Default + Send + Sync + 'static { } fn parse_response(&self, raw: &str) -> Option; +} +pub trait BatchSttAdapter: Clone + Default + Send + Sync + 'static { fn transcribe_file<'a, P: AsRef + Send + 'a>( &'a self, client: &'a reqwest::Client, diff --git a/owhisper/owhisper-client/src/adapter/soniox.rs b/owhisper/owhisper-client/src/adapter/soniox/batch.rs similarity index 56% rename from owhisper/owhisper-client/src/adapter/soniox.rs rename to owhisper/owhisper-client/src/adapter/soniox/batch.rs index a64b9ff082..3c77937337 100644 --- a/owhisper/owhisper-client/src/adapter/soniox.rs +++ 
b/owhisper/owhisper-client/src/adapter/soniox/batch.rs @@ -1,42 +1,21 @@ use std::path::Path; use std::time::Duration; -use hypr_ws::client::Message; use owhisper_interface::batch::{ Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse, Results as BatchResults, Word as BatchWord, }; -use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; use owhisper_interface::ListenParams; use serde::{Deserialize, Serialize}; -use super::{BatchFuture, SttAdapter}; +use super::SonioxAdapter; +use crate::adapter::{BatchFuture, BatchSttAdapter}; use crate::error::Error; -const DEFAULT_API_BASE: &str = "https://api.soniox.com"; const POLL_INTERVAL: Duration = Duration::from_secs(2); const MAX_POLL_ATTEMPTS: u32 = 300; -#[derive(Clone, Default)] -pub struct SonioxAdapter; - impl SonioxAdapter { - fn language_hints(params: &ListenParams) -> Vec { - params - .languages - .iter() - .map(|lang| lang.iso639().code().to_string()) - .collect() - } - - fn api_base_url(api_base: &str) -> String { - if api_base.is_empty() { - DEFAULT_API_BASE.to_string() - } else { - api_base.trim_end_matches('/').to_string() - } - } - async fn upload_file( client: &reqwest::Client, api_base: &str, @@ -60,7 +39,7 @@ impl SonioxAdapter { let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); let form = reqwest::multipart::Form::new().part("file", part); - let url = format!("{}/v1/files", Self::api_base_url(api_base)); + let url = format!("https://{}/v1/files", Self::api_host(api_base)); let response = client .post(&url) .header("Authorization", format!("Bearer {}", api_key)) @@ -118,16 +97,22 @@ impl SonioxAdapter { }) }; + let language_hints = params + .languages + .iter() + .map(|lang| lang.iso639().code().to_string()) + .collect(); + let request = CreateTranscriptionRequest { model, file_id, - language_hints: Self::language_hints(params), + language_hints, enable_speaker_diarization: true, enable_language_identification: 
true, context, }; - let url = format!("{}/v1/transcriptions", Self::api_base_url(api_base)); + let url = format!("https://{}/v1/transcriptions", Self::api_host(api_base)); let response = client .post(&url) .header("Authorization", format!("Bearer {}", api_key)) @@ -164,8 +149,8 @@ impl SonioxAdapter { } let url = format!( - "{}/v1/transcriptions/{}", - Self::api_base_url(api_base), + "https://{}/v1/transcriptions/{}", + Self::api_host(api_base), transcription_id ); @@ -264,8 +249,8 @@ impl SonioxAdapter { } let url = format!( - "{}/v1/transcriptions/{}/transcript", - Self::api_base_url(api_base), + "https://{}/v1/transcriptions/{}/transcript", + Self::api_host(api_base), transcription_id ); @@ -340,238 +325,7 @@ impl SonioxAdapter { } } -impl SttAdapter for SonioxAdapter { - fn supports_native_multichannel(&self) -> bool { - true - } - - fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url { - let mut url: url::Url = api_base.parse().expect("invalid api_base"); - - match url.scheme() { - "http" => { - let _ = url.set_scheme("ws"); - } - "https" => { - let _ = url.set_scheme("wss"); - } - "ws" | "wss" => {} - _ => { - let _ = url.set_scheme("wss"); - } - } - - url - } - - fn build_auth_header(&self, _api_key: Option<&str>) -> Option<(&'static str, String)> { - None - } - - fn keep_alive_message(&self) -> Option { - Some(Message::Text(r#"{"type":"keepalive"}"#.into())) - } - - fn initial_message( - &self, - api_key: Option<&str>, - params: &ListenParams, - channels: u8, - ) -> Option { - let api_key = match api_key { - Some(key) => key, - None => { - tracing::warn!("soniox_api_key_missing"); - return None; - } - }; - - #[derive(Serialize)] - struct Context { - #[serde(skip_serializing_if = "Vec::is_empty")] - terms: Vec, - } - - // https://soniox.com/docs/stt/api-reference/websocket-api#configuration - #[derive(Serialize)] - struct SonioxConfig<'a> { - api_key: &'a str, - model: &'a str, - audio_format: &'a str, - num_channels: 
u8, - sample_rate: u32, - #[serde(skip_serializing_if = "Vec::is_empty")] - language_hints: Vec, - include_nonfinal: bool, - enable_endpoint_detection: bool, - enable_speaker_diarization: bool, - #[serde(skip_serializing_if = "Option::is_none")] - context: Option, - } - - let model = params.model.as_deref().unwrap_or("stt-rt-preview"); - - let context = if params.keywords.is_empty() { - None - } else { - Some(Context { - terms: params.keywords.clone(), - }) - }; - - let cfg = SonioxConfig { - api_key, - model, - audio_format: "pcm_s16le", - num_channels: channels, - sample_rate: params.sample_rate, - language_hints: Self::language_hints(params), - include_nonfinal: true, - enable_endpoint_detection: true, - enable_speaker_diarization: true, - context, - }; - - let json = serde_json::to_string(&cfg).unwrap(); - Some(Message::Text(json.into())) - } - - fn parse_response(&self, raw: &str) -> Option { - #[derive(Deserialize)] - struct Token { - text: String, - #[serde(default)] - start_ms: Option, - #[serde(default)] - end_ms: Option, - #[serde(default)] - confidence: Option, - #[serde(default)] - is_final: Option, - #[serde(default)] - speaker: Option, - #[serde(default)] - channel: Option, - } - - #[derive(Deserialize)] - #[serde(untagged)] - enum SpeakerId { - Num(i32), - Str(String), - } - - impl SpeakerId { - fn as_i32(&self) -> Option { - match self { - SpeakerId::Num(n) => Some(*n), - SpeakerId::Str(s) => s - .trim_start_matches(|c: char| !c.is_ascii_digit()) - .parse() - .ok(), - } - } - } - - #[derive(Deserialize)] - struct SonioxMessage { - #[serde(default)] - tokens: Vec, - #[serde(default)] - finished: Option, - #[serde(default)] - error: Option, - } - - let msg: SonioxMessage = match serde_json::from_str(raw) { - Ok(m) => m, - Err(e) => { - tracing::warn!(error = ?e, raw = raw, "soniox_json_parse_failed"); - return None; - } - }; - - if let Some(error) = msg.error { - tracing::error!(error = error, "soniox_error"); - return None; - } - - let has_fin_token 
= msg.tokens.iter().any(|t| t.text == ""); - let is_finished = msg.finished.unwrap_or(false) || has_fin_token; - - let content_tokens: Vec<_> = msg - .tokens - .iter() - .filter(|t| t.text != "" && t.text != "") - .collect(); - - if content_tokens.is_empty() && !is_finished { - return None; - } - - let all_final = content_tokens.iter().all(|t| t.is_final.unwrap_or(true)); - - let mut words = Vec::with_capacity(content_tokens.len()); - let mut transcript = String::new(); - - let channel_index = content_tokens.first().and_then(|t| t.channel).unwrap_or(0) as i32; - - for t in &content_tokens { - if !transcript.is_empty() && !t.text.starts_with(|c: char| c.is_ascii_punctuation()) { - transcript.push(' '); - } - transcript.push_str(&t.text); - - let start_secs = t.start_ms.unwrap_or(0) as f64 / 1000.0; - let end_secs = t.end_ms.unwrap_or(0) as f64 / 1000.0; - let speaker = t.speaker.as_ref().and_then(|s| s.as_i32()); - - words.push(Word { - word: t.text.clone(), - start: start_secs, - end: end_secs, - confidence: t.confidence.unwrap_or(1.0), - speaker, - punctuated_word: Some(t.text.clone()), - language: None, - }); - } - - let (start, duration) = - if let (Some(first), Some(last)) = (content_tokens.first(), content_tokens.last()) { - let start_secs = first.start_ms.unwrap_or(0) as f64 / 1000.0; - let end_secs = last.end_ms.unwrap_or(0) as f64 / 1000.0; - (start_secs, end_secs - start_secs) - } else { - (0.0, 0.0) - }; - - let channel = Channel { - alternatives: vec![Alternatives { - transcript, - words, - confidence: 1.0, - languages: vec![], - }], - }; - - Some(StreamResponse::TranscriptResponse { - is_final: all_final || is_finished, - speech_final: is_finished, - from_finalize: has_fin_token, - start, - duration, - channel, - metadata: Metadata::default(), - // TODO - channel_index: vec![channel_index, 1], - }) - } - - fn finalize_message(&self) -> Message { - Message::Text(r#"{"type":"finalize"}"#.into()) - } - +impl BatchSttAdapter for SonioxAdapter { fn 
transcribe_file<'a, P: AsRef + Send + 'a>( &'a self, client: &'a reqwest::Client, diff --git a/owhisper/owhisper-client/src/adapter/soniox/live.rs b/owhisper/owhisper-client/src/adapter/soniox/live.rs new file mode 100644 index 0000000000..4fa9612b4c --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/soniox/live.rs @@ -0,0 +1,233 @@ +use hypr_ws::client::Message; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; +use owhisper_interface::ListenParams; +use serde::{Deserialize, Serialize}; + +use super::SonioxAdapter; +use crate::adapter::RealtimeSttAdapter; + +// https://soniox.com/docs/stt/rt/real-time-transcription +// https://soniox.com/docs/stt/api-reference/websocket-api +impl RealtimeSttAdapter for SonioxAdapter { + fn supports_native_multichannel(&self) -> bool { + true + } + + fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url { + format!("wss://{}/transcribe-websocket", Self::ws_host(api_base)) + .parse() + .expect("invalid_ws_url") + } + + fn build_auth_header(&self, _api_key: Option<&str>) -> Option<(&'static str, String)> { + None + } + + fn keep_alive_message(&self) -> Option { + Some(Message::Text(r#"{"type":"keepalive"}"#.into())) + } + + fn initial_message( + &self, + api_key: Option<&str>, + params: &ListenParams, + channels: u8, + ) -> Option { + let api_key = match api_key { + Some(key) => key, + None => { + tracing::warn!("soniox_api_key_missing"); + return None; + } + }; + + let model = params.model.as_deref().unwrap_or("stt-rt-preview"); + + let context = if params.keywords.is_empty() { + None + } else { + Some(Context { + terms: params.keywords.clone(), + }) + }; + + let language_hints = params + .languages + .iter() + .map(|lang| lang.iso639().code().to_string()) + .collect(); + + let cfg = SonioxConfig { + api_key, + model, + audio_format: "pcm_s16le", + num_channels: channels, + sample_rate: params.sample_rate, + language_hints, + enable_endpoint_detection: 
true, + enable_speaker_diarization: true, + context, + }; + + let json = serde_json::to_string(&cfg).unwrap(); + Some(Message::Text(json.into())) + } + + fn parse_response(&self, raw: &str) -> Option { + let msg: SonioxMessage = match serde_json::from_str(raw) { + Ok(m) => m, + Err(e) => { + tracing::warn!(error = ?e, raw = raw, "soniox_json_parse_failed"); + return None; + } + }; + + if let Some(error_msg) = &msg.error_message { + tracing::error!(error_code = ?msg.error_code, error_message = %error_msg, "soniox_error"); + return None; + } + + let has_fin_token = msg.tokens.iter().any(|t| t.text == ""); + let is_finished = msg.finished.unwrap_or(false) || has_fin_token; + + let content_tokens: Vec<_> = msg + .tokens + .iter() + .filter(|t| t.text != "" && t.text != "") + .collect(); + + if content_tokens.is_empty() && !is_finished { + return None; + } + + let all_final = content_tokens.iter().all(|t| t.is_final.unwrap_or(true)); + + let mut words = Vec::with_capacity(content_tokens.len()); + let mut transcript = String::new(); + + let channel_index = content_tokens.first().and_then(|t| t.channel).unwrap_or(0) as i32; + + for t in &content_tokens { + if !transcript.is_empty() && !t.text.starts_with(|c: char| c.is_ascii_punctuation()) { + transcript.push(' '); + } + transcript.push_str(&t.text); + + let start_secs = t.start_ms.unwrap_or(0) as f64 / 1000.0; + let end_secs = t.end_ms.unwrap_or(0) as f64 / 1000.0; + let speaker = t.speaker.as_ref().and_then(|s| s.as_i32()); + + words.push(Word { + word: t.text.clone(), + start: start_secs, + end: end_secs, + confidence: t.confidence.unwrap_or(1.0), + speaker, + punctuated_word: Some(t.text.clone()), + language: None, + }); + } + + let (start, duration) = + if let (Some(first), Some(last)) = (content_tokens.first(), content_tokens.last()) { + let start_secs = first.start_ms.unwrap_or(0) as f64 / 1000.0; + let end_secs = last.end_ms.unwrap_or(0) as f64 / 1000.0; + (start_secs, end_secs - start_secs) + } else { + (0.0, 
0.0) + }; + + let channel = Channel { + alternatives: vec![Alternatives { + transcript, + words, + confidence: 1.0, + languages: vec![], + }], + }; + + Some(StreamResponse::TranscriptResponse { + is_final: all_final || is_finished, + speech_final: is_finished, + from_finalize: has_fin_token, + start, + duration, + channel, + metadata: Metadata::default(), + channel_index: vec![channel_index, 1], + }) + } + + fn finalize_message(&self) -> Message { + Message::Text(r#"{"type":"finalize"}"#.into()) + } +} + +#[derive(Serialize)] +struct Context { + #[serde(skip_serializing_if = "Vec::is_empty")] + terms: Vec, +} + +#[derive(Serialize)] +struct SonioxConfig<'a> { + api_key: &'a str, + model: &'a str, + audio_format: &'a str, + num_channels: u8, + sample_rate: u32, + #[serde(skip_serializing_if = "Vec::is_empty")] + language_hints: Vec, + enable_endpoint_detection: bool, + enable_speaker_diarization: bool, + #[serde(skip_serializing_if = "Option::is_none")] + context: Option, +} + +#[derive(Deserialize)] +struct Token { + text: String, + #[serde(default)] + start_ms: Option, + #[serde(default)] + end_ms: Option, + #[serde(default)] + confidence: Option, + #[serde(default)] + is_final: Option, + #[serde(default)] + speaker: Option, + #[serde(default)] + channel: Option, +} + +#[derive(Deserialize)] +#[serde(untagged)] +enum SpeakerId { + Num(i32), + Str(String), +} + +impl SpeakerId { + fn as_i32(&self) -> Option { + match self { + SpeakerId::Num(n) => Some(*n), + SpeakerId::Str(s) => s + .trim_start_matches(|c: char| !c.is_ascii_digit()) + .parse() + .ok(), + } + } +} + +#[derive(Deserialize)] +struct SonioxMessage { + #[serde(default)] + tokens: Vec, + #[serde(default)] + finished: Option, + #[serde(default)] + error_code: Option, + #[serde(default)] + error_message: Option, +} diff --git a/owhisper/owhisper-client/src/adapter/soniox/mod.rs b/owhisper/owhisper-client/src/adapter/soniox/mod.rs new file mode 100644 index 0000000000..bfe70f2b09 --- /dev/null +++ 
b/owhisper/owhisper-client/src/adapter/soniox/mod.rs @@ -0,0 +1,29 @@ +mod batch; +mod live; + +pub(crate) const DEFAULT_API_HOST: &str = "api.soniox.com"; +pub(crate) const DEFAULT_WS_HOST: &str = "stt-rt.soniox.com"; + +#[derive(Clone, Default)] +pub struct SonioxAdapter; + +impl SonioxAdapter { + pub(crate) fn api_host(api_base: &str) -> String { + if api_base.is_empty() { + return DEFAULT_API_HOST.to_string(); + } + + let url: url::Url = api_base.parse().expect("invalid_api_base"); + url.host_str().unwrap_or(DEFAULT_API_HOST).to_string() + } + + pub(crate) fn ws_host(api_base: &str) -> String { + let api_host = Self::api_host(api_base); + + if let Some(rest) = api_host.strip_prefix("api.") { + format!("stt-rt.{}", rest) + } else { + DEFAULT_WS_HOST.to_string() + } + } +} diff --git a/owhisper/owhisper-client/src/batch.rs b/owhisper/owhisper-client/src/batch.rs index 08f75545cd..683d6fcb0d 100644 --- a/owhisper/owhisper-client/src/batch.rs +++ b/owhisper/owhisper-client/src/batch.rs @@ -4,12 +4,12 @@ use std::path::Path; use owhisper_interface::batch::Response as BatchResponse; use owhisper_interface::ListenParams; -use crate::adapter::SttAdapter; +use crate::adapter::BatchSttAdapter; use crate::error::Error; use crate::DeepgramAdapter; #[derive(Clone)] -pub struct BatchClient { +pub struct BatchClient { client: reqwest::Client, api_base: String, api_key: String, @@ -17,7 +17,7 @@ pub struct BatchClient { _marker: PhantomData, } -impl BatchClient { +impl BatchClient { pub fn new(api_base: String, api_key: String, params: ListenParams) -> Self { Self { client: reqwest::Client::new(), diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index 719337facd..da7ae50d01 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -5,13 +5,15 @@ mod live; use std::marker::PhantomData; -pub use adapter::{ArgmaxAdapter, DeepgramAdapter, SonioxAdapter, SttAdapter}; +pub use adapter::{ + ArgmaxAdapter, 
BatchSttAdapter, DeepgramAdapter, RealtimeSttAdapter, SonioxAdapter, +}; pub use batch::BatchClient; pub use error::Error; pub use hypr_ws; pub use live::{DualHandle, FinalizeHandle, ListenClient, ListenClientDual}; -pub struct ListenClientBuilder { +pub struct ListenClientBuilder { api_base: Option, api_key: Option, params: Option, @@ -29,7 +31,7 @@ impl Default for ListenClientBuilder { } } -impl ListenClientBuilder { +impl ListenClientBuilder { pub fn api_base(mut self, api_base: impl Into) -> Self { self.api_base = Some(api_base.into()); self @@ -45,7 +47,7 @@ impl ListenClientBuilder { self } - pub fn adapter(self) -> ListenClientBuilder { + pub fn adapter(self) -> ListenClientBuilder { ListenClientBuilder { api_base: self.api_base, api_key: self.api_key, @@ -91,12 +93,6 @@ impl ListenClientBuilder { } } - pub fn build_batch(self) -> BatchClient { - let params = self.get_params(); - let api_base = self.get_api_base().to_string(); - BatchClient::new(api_base, self.api_key.unwrap_or_default(), params) - } - pub fn build_single(self) -> ListenClient { self.build_with_channels(1) } @@ -119,3 +115,11 @@ impl ListenClientBuilder { } } } + +impl ListenClientBuilder { + pub fn build_batch(self) -> BatchClient { + let params = self.get_params(); + let api_base = self.get_api_base().to_string(); + BatchClient::new(api_base, self.api_key.unwrap_or_default(), params) + } +} diff --git a/owhisper/owhisper-client/src/live.rs b/owhisper/owhisper-client/src/live.rs index 43e3b3b1b3..bc6e78676e 100644 --- a/owhisper/owhisper-client/src/live.rs +++ b/owhisper/owhisper-client/src/live.rs @@ -9,20 +9,20 @@ use hypr_ws::client::{ use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use crate::{DeepgramAdapter, ListenClientBuilder, SttAdapter}; +use crate::{DeepgramAdapter, ListenClientBuilder, RealtimeSttAdapter}; pub type ListenClientInput = MixedMessage; pub type ListenClientDualInput = MixedMessage<(bytes::Bytes, 
bytes::Bytes), ControlMessage>; #[derive(Clone)] -pub struct ListenClient { +pub struct ListenClient { pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, pub(crate) initial_message: Option, } #[derive(Clone)] -pub struct ListenClientDual { +pub struct ListenClientDual { pub(crate) adapter: A, pub(crate) request: ClientRequestBuilder, pub(crate) initial_message: Option, @@ -181,7 +181,7 @@ impl ListenClient { } } -impl ListenClient { +impl ListenClient { pub async fn from_realtime_audio( self, audio_stream: impl Stream + Send + Unpin + 'static, @@ -219,7 +219,7 @@ impl ListenClient { type DualOutputStream = Pin> + Send>>; -impl ListenClientDual { +impl ListenClientDual { pub async fn from_realtime_audio( self, stream: impl Stream + Send + Unpin + 'static, @@ -366,7 +366,7 @@ where futures_util::stream::select(mic_mapped, spk_mapped) } -fn websocket_client_with_keep_alive( +fn websocket_client_with_keep_alive( request: &ClientRequestBuilder, adapter: &A, ) -> WebSocketClient { @@ -379,7 +379,7 @@ fn websocket_client_with_keep_alive( client } -fn extract_finalize_text(adapter: &A) -> Utf8Bytes { +fn extract_finalize_text(adapter: &A) -> Utf8Bytes { match adapter.finalize_message() { Message::Text(text) => text, _ => r#"{"type":"Finalize"}"#.into(), diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 7c50897ecb..611f1fdee8 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -4,7 +4,9 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use futures_util::StreamExt; use tokio::time::error::Elapsed; -use owhisper_client::{ArgmaxAdapter, DeepgramAdapter, FinalizeHandle, SonioxAdapter, SttAdapter}; +use owhisper_client::{ + ArgmaxAdapter, DeepgramAdapter, FinalizeHandle, RealtimeSttAdapter, SonioxAdapter, +}; use owhisper_interface::stream::{Extra, StreamResponse}; use owhisper_interface::{ControlMessage, MixedMessage}; use ractor::{Actor, ActorName, 
ActorProcessingErr, ActorRef, SupervisionEvent}; @@ -265,7 +267,7 @@ fn build_extra(args: &ListenerArgs) -> (f64, Extra) { (session_offset_secs, extra) } -async fn spawn_rx_task_single_with_adapter( +async fn spawn_rx_task_single_with_adapter( args: ListenerArgs, myself: ActorRef, ) -> Result< @@ -324,7 +326,7 @@ async fn spawn_rx_task_single_with_adapter( Ok((ChannelSender::Single(tx), rx_task, shutdown_tx)) } -async fn spawn_rx_task_dual_with_adapter( +async fn spawn_rx_task_dual_with_adapter( args: ListenerArgs, myself: ActorRef, ) -> Result<