Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 109 additions & 72 deletions plugins/listener/src/actors/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,35 @@ use owhisper_interface::{ControlMessage, MixedMessage, Word2};
use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent};
use tauri_specta::Event;

use crate::{manager::TranscriptManager, SessionEvent};
use crate::{
manager::{TranscriptManager, WordsByChannel},
SessionEvent,
};

// Not too short to support non-realtime pipelines like whisper.cpp
const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(15 * 60);

pub enum ListenerMsg {
Audio(Bytes, Bytes),
StreamResponse(owhisper_interface::StreamResponse),
StreamError(String),
StreamEnded,
StreamTimeout,
StreamStartFailed(String),
}

#[derive(Clone)]
pub struct ListenerArgs {
pub app: tauri::AppHandle,
pub session_id: String,
pub languages: Vec<hypr_language::Language>,
pub onboarding: bool,
pub session_start_ts_ms: u64,
pub partial_words_by_channel: WordsByChannel,
}

pub struct ListenerState {
pub args: ListenerArgs,
pub manager: TranscriptManager,
tx: tokio::sync::mpsc::Sender<MixedMessage<(Bytes, Bytes), ControlMessage>>,
rx_task: tokio::task::JoinHandle<()>,
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
Expand Down Expand Up @@ -55,11 +66,24 @@ impl Actor for ListenerActor {
tracing::info!("{:?}", r);
}

let (tx, rx_task, shutdown_tx) = spawn_rx_task(args, myself).await.unwrap();
let current_timestamp_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;

let manager = TranscriptManager::builder()
.with_manager_offset(current_timestamp_ms)
.with_existing_partial_words(args.partial_words_by_channel.clone())
.build();

let (tx, rx_task, shutdown_tx) = spawn_rx_task(args.clone(), myself).await?;

let state = ListenerState {
args,
tx,
rx_task,
shutdown_tx: Some(shutdown_tx),
manager,
};

Ok(state)
Expand All @@ -79,14 +103,89 @@ impl Actor for ListenerActor {

async fn handle(
&self,
_myself: ActorRef<Self::Msg>,
myself: ActorRef<Self::Msg>,
message: Self::Msg,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
ListenerMsg::Audio(mic, spk) => {
let _ = state.tx.try_send(MixedMessage::Audio((mic, spk)));
}

ListenerMsg::StreamResponse(response) => {
let diff = state.manager.append(response);

let partial_words_by_channel: HashMap<usize, Vec<Word2>> = diff
.partial_words
.iter()
.map(|(channel_idx, words)| {
(
*channel_idx,
words
.iter()
.map(|w| Word2::from(w.clone()))
.collect::<Vec<_>>(),
)
})
.collect();

SessionEvent::PartialWords {
words: partial_words_by_channel,
}
.emit(&state.args.app)?;

let final_words_by_channel: HashMap<usize, Vec<Word2>> = diff
.final_words
.iter()
.map(|(channel_idx, words)| {
(
*channel_idx,
words
.iter()
.map(|w| Word2::from(w.clone()))
.collect::<Vec<_>>(),
)
})
.collect();

update_session(
&state.args.app,
&state.args.session_id,
final_words_by_channel
.clone()
.values()
.flatten()
.cloned()
.collect(),
)
.await
.unwrap();

SessionEvent::FinalWords {
words: final_words_by_channel,
}
.emit(&state.args.app)?;
}

ListenerMsg::StreamStartFailed(error) => {
tracing::error!("listen_ws_connect_failed: {}", error);
myself.stop(Some(format!("listen_ws_connect_failed: {}", error)));
}

ListenerMsg::StreamError(error) => {
tracing::info!("listen_stream_error: {}", error);
myself.stop(None);
}

ListenerMsg::StreamEnded => {
tracing::info!("listen_stream_ended");
myself.stop(None);
}

ListenerMsg::StreamTimeout => {
tracing::info!("listen_stream_timeout");
myself.stop(None);
}
}
Ok(())
}
Expand Down Expand Up @@ -125,8 +224,6 @@ async fn spawn_rx_task(
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();

let app = args.app.clone();
let session_id = args.session_id.clone();
let session_start_ts_ms = args.session_start_ts_ms;

let conn = {
use tauri_plugin_local_stt::LocalSttPluginExt;
Expand All @@ -149,17 +246,12 @@ async fn spawn_rx_task(
let (listen_stream, handle) = match client.from_realtime_audio(outbound).await {
Ok(res) => res,
Err(e) => {
tracing::error!("listen_ws_connect_failed: {:?}", e);
myself.stop(Some(format!("listen_ws_connect_failed: {:?}", e)));
let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e)));
return;
}
};
futures_util::pin_mut!(listen_stream);

let mut manager = TranscriptManager::builder()
.with_unix_timestamp(session_start_ts_ms)
.build();

loop {
tokio::select! {
_ = &mut shutdown_rx => {
Expand All @@ -169,82 +261,27 @@ async fn spawn_rx_task(
result = tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()) => {
match result {
Ok(Some(Ok(response))) => {
let diff = manager.append(response.clone());

let partial_words_by_channel: HashMap<usize, Vec<Word2>> = diff
.partial_words
.iter()
.map(|(channel_idx, words)| {
(
*channel_idx,
words
.iter()
.map(|w| Word2::from(w.clone()))
.collect::<Vec<_>>(),
)
})
.collect();

SessionEvent::PartialWords {
words: partial_words_by_channel,
}
.emit(&app)
.unwrap();

let final_words_by_channel: HashMap<usize, Vec<Word2>> = diff
.final_words
.iter()
.map(|(channel_idx, words)| {
(
*channel_idx,
words
.iter()
.map(|w| Word2::from(w.clone()))
.collect::<Vec<_>>(),
)
})
.collect();

update_session(
&app,
&session_id,
final_words_by_channel
.clone()
.values()
.flatten()
.cloned()
.collect(),
)
.await
.unwrap();

SessionEvent::FinalWords {
words: final_words_by_channel,
}
.emit(&app)
.unwrap();
let _ = myself.send_message(ListenerMsg::StreamResponse(response));
}
// Something went wrong while sending or receiving a websocket message. Should restart.
Ok(Some(Err(e))) => {
tracing::info!("listen_stream_error: {:?}", e);
let _ = myself.send_message(ListenerMsg::StreamError(format!("{:?}", e)));
break;
}
// Stream ended gracefully. Safe to stop the whole session.
// Stream ended gracefully. Safe to stop the whole session.
Ok(None) => {
tracing::info!("listen_stream_ended");
let _ = myself.send_message(ListenerMsg::StreamEnded);
break;
}
// We're not hearing back any transcript. Better to stop the whole session.
Err(_) => {
tracing::info!("listen_stream_timeout");
let _ = myself.send_message(ListenerMsg::StreamTimeout);
break;
}
}
}
}
}

myself.stop(None);
});

Ok((tx, rx_task, shutdown_tx))
Expand Down
49 changes: 26 additions & 23 deletions plugins/listener/src/actors/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use tokio_util::sync::CancellationToken;

use crate::{
actors::{
ListenerActor, ListenerArgs, ListenerMsg, ProcArgs, ProcMsg, ProcessorActor, RecArgs,
RecMsg, RecorderActor, SourceActor, SourceArgs, SourceMsg,
ListenerActor, ListenerArgs, ListenerMsg, ListenerState, ProcArgs, ProcMsg, ProcessorActor,
RecArgs, RecMsg, RecorderActor, SourceActor, SourceArgs, SourceMsg,
},
SessionEvent,
};
Expand All @@ -33,7 +33,6 @@ pub struct SessionArgs {
pub struct SessionState {
app: tauri::AppHandle,
session_id: String,
session_start_ts_ms: u64,
languages: Vec<hypr_language::Language>,
onboarding: bool,
token: CancellationToken,
Expand Down Expand Up @@ -85,15 +84,9 @@ impl Actor for SessionActor {
let _ = args.app.set_start_disabled(true);
}

let session_start_ts_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;

let state = SessionState {
app: args.app,
session_id,
session_start_ts_ms,
languages,
onboarding,
token: cancellation_token,
Expand Down Expand Up @@ -192,21 +185,30 @@ impl Actor for SessionActor {
SupervisionEvent::ActorStarted(actor) => {
tracing::info!("{:?}_actor_started", actor.get_name());
}

SupervisionEvent::ActorFailed(actor, _)
| SupervisionEvent::ActorTerminated(actor, _, _) => {
SupervisionEvent::ActorTerminated(actor, maybe_state, _) => {
let actor_name = actor
.get_name()
.map(|n| n.to_string())
.unwrap_or_else(|| "unknown".to_string());

if actor_name == ListenerActor::name() {
Self::start_listener(myself.get_cell(), state).await?;
let last_state: Option<ListenerState> =
maybe_state.and_then(|mut s| s.take().ok());

Self::start_listener(
myself.get_cell(),
state,
last_state.map(|s| ListenerArgs {
partial_words_by_channel: s.manager.partial_words_by_channel,
..s.args
}),
)
.await?;
} else {
let _ = myself.stop_and_wait(None, None).await;
}
}

SupervisionEvent::ActorFailed(_, _) => {}
_ => {}
}

Expand Down Expand Up @@ -254,7 +256,7 @@ impl SessionActor {
) -> Result<(), ActorProcessingErr> {
Self::start_processor(supervisor.clone(), state).await?;
Self::start_source(supervisor.clone(), state).await?;
Self::start_listener(supervisor.clone(), state).await?;
Self::start_listener(supervisor.clone(), state, None).await?;

if state.record_enabled {
Self::start_recorder(supervisor, state).await?;
Expand Down Expand Up @@ -359,18 +361,19 @@ impl SessionActor {

async fn start_listener(
supervisor: ActorCell,
state: &SessionState,
session_state: &SessionState,
listener_args: Option<ListenerArgs>,
) -> Result<ActorRef<ListenerMsg>, ActorProcessingErr> {
let (listen_ref, _) = Actor::spawn_linked(
Some(ListenerActor::name()),
ListenerActor,
ListenerArgs {
app: state.app.clone(),
session_id: state.session_id.to_string(),
languages: state.languages.clone(),
onboarding: state.onboarding,
session_start_ts_ms: state.session_start_ts_ms,
},
listener_args.unwrap_or(ListenerArgs {
app: session_state.app.clone(),
session_id: session_state.session_id.to_string(),
languages: session_state.languages.clone(),
onboarding: session_state.onboarding,
partial_words_by_channel: Default::default(),
}),
supervisor,
)
.await?;
Expand Down
Loading
Loading