diff --git a/owhisper/owhisper-client/src/adapter/assemblyai/live.rs b/owhisper/owhisper-client/src/adapter/assemblyai/live.rs index 0bf255a156..ecbf4c2423 100644 --- a/owhisper/owhisper-client/src/adapter/assemblyai/live.rs +++ b/owhisper/owhisper-client/src/adapter/assemblyai/live.rs @@ -1,9 +1,10 @@ use hypr_ws::client::Message; -use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse}; use owhisper_interface::ListenParams; use serde::Deserialize; use super::AssemblyAIAdapter; +use crate::adapter::parsing::{calculate_time_span, ms_to_secs, WordBuilder}; use crate::adapter::RealtimeSttAdapter; // https://www.assemblyai.com/docs/api-reference/streaming-api/streaming-api.md @@ -213,30 +214,20 @@ impl AssemblyAIAdapter { let speech_final = turn.end_of_turn; let from_finalize = false; - let words: Vec = turn + let words: Vec<_> = turn .words .iter() .map(|w| { - let start_secs = w.start as f64 / 1000.0; - let end_secs = w.end as f64 / 1000.0; - - Word { - word: w.text.clone(), - start: start_secs, - end: end_secs, - confidence: w.confidence, - speaker: None, - punctuated_word: Some(w.text.clone()), - language: turn.language_code.clone(), - } + WordBuilder::new(&w.text) + .start(ms_to_secs(w.start)) + .end(ms_to_secs(w.end)) + .confidence(w.confidence) + .language(turn.language_code.clone()) + .build() }) .collect(); - let (start, duration) = if let (Some(first), Some(last)) = (words.first(), words.last()) { - (first.start, last.end - first.start) - } else { - (0.0, 0.0) - }; + let (start, duration) = calculate_time_span(&words); let transcript = if turn.turn_is_formatted { turn.transcript.clone() diff --git a/owhisper/owhisper-client/src/adapter/fireworks/live.rs b/owhisper/owhisper-client/src/adapter/fireworks/live.rs index 943f8ab757..6beff59be6 100644 --- a/owhisper/owhisper-client/src/adapter/fireworks/live.rs +++ b/owhisper/owhisper-client/src/adapter/fireworks/live.rs @@ -1,9 +1,10 @@ use hypr_ws::client::Message; -use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse}; use owhisper_interface::ListenParams; use serde::Deserialize; use super::FireworksAdapter; +use crate::adapter::parsing::WordBuilder; use crate::adapter::RealtimeSttAdapter; // https://docs.fireworks.ai/guides/querying-asr-models#streaming-transcription @@ -75,16 +76,15 @@ impl RealtimeSttAdapter for FireworksAdapter { let is_final = words_to_use.iter().all(|w| w.is_final); - let words: Vec = words_to_use + let words: Vec<_> = words_to_use .iter() - .map(|w| Word { - word: w.word.clone(), - start: w.start.unwrap_or(0.0), - end: w.end.unwrap_or(0.0), - confidence: w.probability.unwrap_or(1.0), - speaker: None, - punctuated_word: Some(w.word.clone()), - language: w.language.clone(), + .map(|w| { + WordBuilder::new(&w.word) + .start(w.start.unwrap_or(0.0)) + .end(w.end.unwrap_or(0.0)) + .confidence(w.probability.unwrap_or(1.0)) + .language(w.language.clone()) + .build() }) .collect(); @@ -121,17 +121,16 @@ impl RealtimeSttAdapter for FireworksAdapter { } else if !msg.text.is_empty() { let is_final = msg.words.iter().all(|w| w.is_final); - let words: Vec = msg + let words: Vec<_> = msg .words .iter() - .map(|w| Word { - word: w.word.clone(), - start: w.start.unwrap_or(0.0), - end: w.end.unwrap_or(0.0), - confidence: w.probability.unwrap_or(1.0), - speaker: None, - punctuated_word: Some(w.word.clone()), - language: w.language.clone(), + .map(|w| { + WordBuilder::new(&w.word) + .start(w.start.unwrap_or(0.0)) + .end(w.end.unwrap_or(0.0)) + .confidence(w.probability.unwrap_or(1.0)) + .language(w.language.clone()) + .build() }) .collect(); diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs index ef7706e637..58cc275b51 100644 --- a/owhisper/owhisper-client/src/adapter/mod.rs +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -4,6 +4,7 @@ mod deepgram; mod deepgram_compat; mod fireworks; mod owhisper; +pub mod parsing; mod soniox; pub use argmax::*; diff --git a/owhisper/owhisper-client/src/adapter/parsing.rs b/owhisper/owhisper-client/src/adapter/parsing.rs new file mode 100644 index 0000000000..9bc5835bf9 --- /dev/null +++ b/owhisper/owhisper-client/src/adapter/parsing.rs @@ -0,0 +1,187 @@ +use owhisper_interface::stream::Word; + +pub fn parse_speaker_id(value: &str) -> Option { + if let Ok(n) = value.parse::() { + return Some(n); + } + + value + .trim_start_matches(|c: char| !c.is_ascii_digit()) + .parse() + .ok() +} + +pub fn ms_to_secs(ms: u64) -> f64 { + ms as f64 / 1000.0 +} + +pub fn ms_to_secs_opt(ms: Option) -> f64 { + ms.map(ms_to_secs).unwrap_or(0.0) +} + +pub trait HasTimeSpan { + fn start_time(&self) -> f64; + fn end_time(&self) -> f64; +} + +impl HasTimeSpan for Word { + fn start_time(&self) -> f64 { + self.start + } + + fn end_time(&self) -> f64 { + self.end + } +} + +pub fn calculate_time_span(words: &[T]) -> (f64, f64) { + match (words.first(), words.last()) { + (Some(first), Some(last)) => { + let start = first.start_time(); + let end = last.end_time(); + (start, end - start) + } + _ => (0.0, 0.0), + } +} + +pub struct WordBuilder { + word: String, + start: f64, + end: f64, + confidence: f64, + speaker: Option, + punctuated_word: Option, + language: Option, +} + +impl WordBuilder { + pub fn new(word: impl Into) -> Self { + let word = word.into(); + Self { + punctuated_word: Some(word.clone()), + word, + start: 0.0, + end: 0.0, + confidence: 1.0, + speaker: None, + language: None, + } + } + + pub fn start(mut self, start: f64) -> Self { + self.start = start; + self + } + + pub fn end(mut self, end: f64) -> Self { + self.end = end; + self + } + + pub fn confidence(mut self, confidence: f64) -> Self { + self.confidence = confidence; + self + } + + pub fn speaker(mut self, speaker: Option) -> Self { + self.speaker = speaker; + self + } + + pub fn language(mut self, language: Option) -> Self { + self.language = language; + self + } + + pub fn build(self) -> Word { + Word { + word: self.word, + start: self.start, + end: self.end, + confidence: self.confidence, + speaker: self.speaker, + punctuated_word: self.punctuated_word, + language: self.language, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_speaker_id_numeric() { + assert_eq!(parse_speaker_id("0"), Some(0)); + assert_eq!(parse_speaker_id("1"), Some(1)); + assert_eq!(parse_speaker_id("42"), Some(42)); + } + + #[test] + fn test_parse_speaker_id_prefixed() { + assert_eq!(parse_speaker_id("SPEAKER_0"), Some(0)); + assert_eq!(parse_speaker_id("SPEAKER_1"), Some(1)); + assert_eq!(parse_speaker_id("speaker_2"), Some(2)); + } + + #[test] + fn test_parse_speaker_id_invalid() { + assert_eq!(parse_speaker_id(""), None); + assert_eq!(parse_speaker_id("abc"), None); + } + + #[test] + fn test_ms_to_secs() { + assert_eq!(ms_to_secs(0), 0.0); + assert_eq!(ms_to_secs(1000), 1.0); + assert_eq!(ms_to_secs(1500), 1.5); + } + + #[test] + fn test_ms_to_secs_opt() { + assert_eq!(ms_to_secs_opt(None), 0.0); + assert_eq!(ms_to_secs_opt(Some(1000)), 1.0); + assert_eq!(ms_to_secs_opt(Some(2500)), 2.5); + } + + #[test] + fn test_calculate_time_span_empty() { + let words: Vec = vec![]; + assert_eq!(calculate_time_span(&words), (0.0, 0.0)); + } + + #[test] + fn test_calculate_time_span_single() { + let words = vec![WordBuilder::new("hello").start(1.0).end(2.0).build()]; + assert_eq!(calculate_time_span(&words), (1.0, 1.0)); + } + + #[test] + fn test_calculate_time_span_multiple() { + let words = vec![ + WordBuilder::new("hello").start(1.0).end(2.0).build(), + WordBuilder::new("world").start(2.5).end(3.5).build(), + ]; + assert_eq!(calculate_time_span(&words), (1.0, 2.5)); + } + + #[test] + fn test_word_builder() { + let word = WordBuilder::new("test") + .start(1.5) + .end(2.5) + .confidence(0.95) + .speaker(Some(1)) + .language(Some("en".to_string())) + .build(); + + assert_eq!(word.word, "test"); + assert_eq!(word.start, 1.5); + assert_eq!(word.end, 2.5); + assert_eq!(word.confidence, 0.95); + assert_eq!(word.speaker, Some(1)); + assert_eq!(word.punctuated_word, Some("test".to_string())); + assert_eq!(word.language, Some("en".to_string())); + } +} diff --git a/owhisper/owhisper-client/src/adapter/soniox/live.rs b/owhisper/owhisper-client/src/adapter/soniox/live.rs index 202592dc43..ba29a9412a 100644 --- a/owhisper/owhisper-client/src/adapter/soniox/live.rs +++ b/owhisper/owhisper-client/src/adapter/soniox/live.rs @@ -1,9 +1,10 @@ use hypr_ws::client::Message; -use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse}; use owhisper_interface::ListenParams; use serde::{Deserialize, Serialize}; use super::SonioxAdapter; +use crate::adapter::parsing::{ms_to_secs_opt, WordBuilder}; use crate::adapter::RealtimeSttAdapter; // https://soniox.com/docs/stt/rt/real-time-transcription @@ -229,24 +230,23 @@ impl SonioxAdapter { transcript.push_str(&t.text); - let start_secs = t.start_ms.unwrap_or(0) as f64 / 1000.0; - let end_secs = t.end_ms.unwrap_or(0) as f64 / 1000.0; + let start_secs = ms_to_secs_opt(t.start_ms); + let end_secs = ms_to_secs_opt(t.end_ms); let speaker = t.speaker.as_ref().and_then(|s| s.as_i32()); - words.push(Word { - word: t.text.clone(), - start: start_secs, - end: end_secs, - confidence: t.confidence.unwrap_or(1.0), - speaker, - punctuated_word: Some(t.text.clone()), - language: None, - }); + words.push( + WordBuilder::new(&t.text) + .start(start_secs) + .end(end_secs) + .confidence(t.confidence.unwrap_or(1.0)) + .speaker(speaker) + .build(), + ); } let (start, duration) = if let (Some(first), Some(last)) = (tokens.first(), tokens.last()) { - let start_secs = first.start_ms.unwrap_or(0) as f64 / 1000.0; - let end_secs = last.end_ms.unwrap_or(0) as f64 / 1000.0; + let start_secs = ms_to_secs_opt(first.start_ms); + let end_secs = ms_to_secs_opt(last.end_ms); (start_secs, end_secs - start_secs) } else { (0.0, 0.0)