diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 20e30fd14f..83b3de573c 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -8,24 +8,35 @@ use owhisper_interface::{ControlMessage, MixedMessage, Word2}; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; use tauri_specta::Event; -use crate::{manager::TranscriptManager, SessionEvent}; +use crate::{ + manager::{TranscriptManager, WordsByChannel}, + SessionEvent, +}; // Not too short to support non-realtime pipelines like whisper.cpp const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(15 * 60); pub enum ListenerMsg { Audio(Bytes, Bytes), + StreamResponse(owhisper_interface::StreamResponse), + StreamError(String), + StreamEnded, + StreamTimeout, + StreamStartFailed(String), } +#[derive(Clone)] pub struct ListenerArgs { pub app: tauri::AppHandle, pub session_id: String, pub languages: Vec, pub onboarding: bool, - pub session_start_ts_ms: u64, + pub partial_words_by_channel: WordsByChannel, } pub struct ListenerState { + pub args: ListenerArgs, + pub manager: TranscriptManager, tx: tokio::sync::mpsc::Sender>, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, @@ -55,11 +66,24 @@ impl Actor for ListenerActor { tracing::info!("{:?}", r); } - let (tx, rx_task, shutdown_tx) = spawn_rx_task(args, myself).await.unwrap(); + let current_timestamp_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + let manager = TranscriptManager::builder() + .with_manager_offset(current_timestamp_ms) + .with_existing_partial_words(args.partial_words_by_channel.clone()) + .build(); + + let (tx, rx_task, shutdown_tx) = spawn_rx_task(args.clone(), myself).await?; + let state = ListenerState { + args, tx, rx_task, shutdown_tx: Some(shutdown_tx), + manager, }; Ok(state) @@ -79,7 +103,7 @@ impl Actor for ListenerActor { async fn handle( &self, - _myself: ActorRef, + myself: ActorRef, message: Self::Msg, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { @@ -87,6 +111,81 @@ impl Actor for ListenerActor { ListenerMsg::Audio(mic, spk) => { let _ = state.tx.try_send(MixedMessage::Audio((mic, spk))); } + + ListenerMsg::StreamResponse(response) => { + let diff = state.manager.append(response); + + let partial_words_by_channel: HashMap> = diff + .partial_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + SessionEvent::PartialWords { + words: partial_words_by_channel, + } + .emit(&state.args.app)?; + + let final_words_by_channel: HashMap> = diff + .final_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + update_session( + &state.args.app, + &state.args.session_id, + final_words_by_channel + .clone() + .values() + .flatten() + .cloned() + .collect(), + ) + .await + .unwrap(); + + SessionEvent::FinalWords { + words: final_words_by_channel, + } + .emit(&state.args.app)?; + } + + ListenerMsg::StreamStartFailed(error) => { + tracing::error!("listen_ws_connect_failed: {}", error); + myself.stop(Some(format!("listen_ws_connect_failed: {}", error))); + } + + ListenerMsg::StreamError(error) => { + tracing::info!("listen_stream_error: {}", error); + myself.stop(None); + } + + ListenerMsg::StreamEnded => { + tracing::info!("listen_stream_ended"); + myself.stop(None); + } + + ListenerMsg::StreamTimeout => { + tracing::info!("listen_stream_timeout"); + myself.stop(None); + } } Ok(()) } @@ -125,8 +224,6 @@ async fn spawn_rx_task( let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); let app = args.app.clone(); - let session_id = args.session_id.clone(); - let session_start_ts_ms = args.session_start_ts_ms; let conn = { use tauri_plugin_local_stt::LocalSttPluginExt; @@ -149,17 +246,12 @@ async fn spawn_rx_task( let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { Ok(res) => res, Err(e) => { - tracing::error!("listen_ws_connect_failed: {:?}", e); - myself.stop(Some(format!("listen_ws_connect_failed: {:?}", e))); + let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e))); return; } }; futures_util::pin_mut!(listen_stream); - let mut manager = TranscriptManager::builder() - .with_unix_timestamp(session_start_ts_ms) - .build(); - loop { tokio::select! { _ = &mut shutdown_rx => { @@ -169,82 +261,27 @@ async fn spawn_rx_task( result = tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()) => { match result { Ok(Some(Ok(response))) => { - let diff = manager.append(response.clone()); - - let partial_words_by_channel: HashMap> = diff - .partial_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - - SessionEvent::PartialWords { - words: partial_words_by_channel, - } - .emit(&app) - .unwrap(); - - let final_words_by_channel: HashMap> = diff - .final_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - - update_session( - &app, - &session_id, - final_words_by_channel - .clone() - .values() - .flatten() - .cloned() - .collect(), - ) - .await - .unwrap(); - - SessionEvent::FinalWords { - words: final_words_by_channel, - } - .emit(&app) - .unwrap(); + let _ = myself.send_message(ListenerMsg::StreamResponse(response)); } // Something went wrong while sending or receiving a websocket message. Should restart. Ok(Some(Err(e))) => { - tracing::info!("listen_stream_error: {:?}", e); + let _ = myself.send_message(ListenerMsg::StreamError(format!("{:?}", e))); break; } - // Stream ended gracefully. Safe to stop the whole session. + // Stream ended gracefully. Safe to stop the whole session. Ok(None) => { - tracing::info!("listen_stream_ended"); + let _ = myself.send_message(ListenerMsg::StreamEnded); break; } // We're not hearing back any transcript. Better to stop the whole session. Err(_) => { - tracing::info!("listen_stream_timeout"); + let _ = myself.send_message(ListenerMsg::StreamTimeout); break; } } } } } - - myself.stop(None); }); Ok((tx, rx_task, shutdown_tx)) diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs index de5f3caade..720af6a819 100644 --- a/plugins/listener/src/actors/session.rs +++ b/plugins/listener/src/actors/session.rs @@ -9,8 +9,8 @@ use tokio_util::sync::CancellationToken; use crate::{ actors::{ - ListenerActor, ListenerArgs, ListenerMsg, ProcArgs, ProcMsg, ProcessorActor, RecArgs, - RecMsg, RecorderActor, SourceActor, SourceArgs, SourceMsg, + ListenerActor, ListenerArgs, ListenerMsg, ListenerState, ProcArgs, ProcMsg, ProcessorActor, + RecArgs, RecMsg, RecorderActor, SourceActor, SourceArgs, SourceMsg, }, SessionEvent, }; @@ -33,7 +33,6 @@ pub struct SessionArgs { pub struct SessionState { app: tauri::AppHandle, session_id: String, - session_start_ts_ms: u64, languages: Vec, onboarding: bool, token: CancellationToken, @@ -85,15 +84,9 @@ impl Actor for SessionActor { let _ = args.app.set_start_disabled(true); } - let session_start_ts_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as u64; - let state = SessionState { app: args.app, session_id, - session_start_ts_ms, languages, onboarding, token: cancellation_token, @@ -192,21 +185,30 @@ impl Actor for SessionActor { SupervisionEvent::ActorStarted(actor) => { tracing::info!("{:?}_actor_started", actor.get_name()); } - - SupervisionEvent::ActorFailed(actor, _) - | SupervisionEvent::ActorTerminated(actor, _, _) => { + SupervisionEvent::ActorTerminated(actor, maybe_state, _) => { let actor_name = actor .get_name() .map(|n| n.to_string()) .unwrap_or_else(|| "unknown".to_string()); if actor_name == ListenerActor::name() { - Self::start_listener(myself.get_cell(), state).await?; + let last_state: Option = + maybe_state.and_then(|mut s| s.take().ok()); + + Self::start_listener( + myself.get_cell(), + state, + last_state.map(|s| ListenerArgs { + partial_words_by_channel: s.manager.partial_words_by_channel, + ..s.args + }), + ) + .await?; } else { let _ = myself.stop_and_wait(None, None).await; } } - + SupervisionEvent::ActorFailed(_, _) => {} _ => {} } @@ -254,7 +256,7 @@ impl SessionActor { ) -> Result<(), ActorProcessingErr> { Self::start_processor(supervisor.clone(), state).await?; Self::start_source(supervisor.clone(), state).await?; - Self::start_listener(supervisor.clone(), state).await?; + Self::start_listener(supervisor.clone(), state, None).await?; if state.record_enabled { Self::start_recorder(supervisor, state).await?; @@ -359,18 +361,19 @@ impl SessionActor { async fn start_listener( supervisor: ActorCell, - state: &SessionState, + session_state: &SessionState, + listener_args: Option, ) -> Result, ActorProcessingErr> { let (listen_ref, _) = Actor::spawn_linked( Some(ListenerActor::name()), ListenerActor, - ListenerArgs { - app: state.app.clone(), - session_id: state.session_id.to_string(), - languages: state.languages.clone(), - onboarding: state.onboarding, - session_start_ts_ms: state.session_start_ts_ms, - }, + listener_args.unwrap_or(ListenerArgs { + app: session_state.app.clone(), + session_id: session_state.session_id.to_string(), + languages: session_state.languages.clone(), + onboarding: session_state.onboarding, + partial_words_by_channel: Default::default(), + }), supervisor, ) .await?; diff --git a/plugins/listener/src/manager.rs b/plugins/listener/src/manager.rs index ad77cb2d4e..1c610557b5 100644 --- a/plugins/listener/src/manager.rs +++ b/plugins/listener/src/manager.rs @@ -1,29 +1,38 @@ use std::collections::HashMap; +pub type WordsByChannel = HashMap>; + #[derive(Default)] pub struct TranscriptManagerBuilder { - session_start_timestamp_ms: Option, + manager_offset: Option, + partial_words_by_channel: Option, } impl TranscriptManagerBuilder { - pub fn with_unix_timestamp(mut self, session_start_timestamp_ms: u64) -> Self { - self.session_start_timestamp_ms = Some(session_start_timestamp_ms); + // unix timestamp in ms + pub fn with_manager_offset(mut self, manager_offset: u64) -> Self { + self.manager_offset = Some(manager_offset); + self + } + + pub fn with_existing_partial_words(mut self, m: impl Into) -> Self { + self.partial_words_by_channel = Some(m.into()); self } pub fn build(self) -> TranscriptManager { TranscriptManager { id: uuid::Uuid::new_v4(), - partial_words_by_channel: HashMap::new(), - session_start_timestamp_ms: self.session_start_timestamp_ms.unwrap_or(0), + partial_words_by_channel: self.partial_words_by_channel.unwrap_or_default(), + manager_offset: self.manager_offset.unwrap_or(0), } } } pub struct TranscriptManager { - id: uuid::Uuid, - partial_words_by_channel: HashMap>, - session_start_timestamp_ms: u64, + pub id: uuid::Uuid, + pub partial_words_by_channel: WordsByChannel, + pub manager_offset: u64, } impl TranscriptManager { @@ -110,8 +119,8 @@ impl TranscriptManager { w.speaker = Some(speaker); } - let start_ms = self.session_start_timestamp_ms as f64 + (w.start * 1000.0); - let end_ms = self.session_start_timestamp_ms as f64 + (w.end * 1000.0); + let start_ms = self.manager_offset as f64 + (w.start * 1000.0); + let end_ms = self.manager_offset as f64 + (w.end * 1000.0); w.start = start_ms / 1000.0; w.end = end_ms / 1000.0;