diff --git a/Cargo.lock b/Cargo.lock index 01300c6452..80529ce0e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -349,7 +349,7 @@ dependencies = [ "apalis-core", "async-stream", "chrono", - "cron", + "cron 0.15.0", "futures", "tower 0.5.2", ] @@ -1802,16 +1802,39 @@ dependencies = [ "serde_with", ] +[[package]] +name = "bon" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97493a391b4b18ee918675fb8663e53646fd09321c58b46afa04e8ce2499c869" +dependencies = [ + "bon-macros 2.3.0", + "rustversion", +] + [[package]] name = "bon" version = "3.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2529c31017402be841eb45892278a6c21a000c0a17643af326c73a73f83f0fb" dependencies = [ - "bon-macros", + "bon-macros 3.7.2", "rustversion", ] +[[package]] +name = "bon-macros" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2af3eac944c12cdf4423eab70d310da0a8e5851a18ffb192c0a5e3f7ae1663" +dependencies = [ + "darling 0.20.11", + "ident_case", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "bon-macros" version = "3.7.2" @@ -3009,6 +3032,17 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "cron" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8c3e73077b4b4a6ab1ea5047c37c57aee77657bc8ecd6f29b0af082d0b0c07" +dependencies = [ + "chrono", + "nom 7.1.3", + "once_cell", +] + [[package]] name = "cron" version = "0.15.0" @@ -4542,18 +4576,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "flume" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "spin 0.9.8", -] - [[package]] name = "fnv" version = "1.0.7" @@ -6279,7 +6301,7 @@ checksum = "7314c5dcd0feb905728aa809f46d10a58587be2bdd90f3003e09bcef05e919dc" dependencies = [ "async-trait", "base64 0.22.1", - "bon", + "bon 3.7.2", "google-cloud-gax", "http 1.3.1", "reqwest 0.12.23", @@ -8930,15 +8952,6 @@ dependencies = [ "url", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom 0.2.16", -] - [[package]] name = "native-tls" version = "0.2.14" @@ -11137,6 +11150,43 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "ractor" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a03628f080f90360ed29f8a577b90ad9488820e561d33d22f34f241e58845d" +dependencies = [ + "bon 2.3.0", + "dashmap", + "futures", + "js-sys", + "once_cell", + "strum 0.26.3", + "tokio", + "tokio_with_wasm", + "tracing", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-time", +] + +[[package]] +name = "ractor_actors" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5142ea740eeb967f05caa66e911f58d2863d33a405c0810bfc9fd9b17a9d5b15" +dependencies = [ + "async-trait", + "chrono", + "cron 0.12.1", + "notify", + "ractor", + "tokio", + "tokio-rustls 0.26.2", + "tokio-stream", + "tracing", +] + [[package]] name = "rand" version = "0.7.3" @@ -11791,7 +11841,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin 0.5.2", + "spin", "untrusted 0.7.1", "web-sys", "winapi", @@ -13397,15 +13447,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] - [[package]] name = "spki" version = "0.6.0" @@ -13456,27 +13497,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "statig" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42c467cc59664639bf70b8225b1b4a9c30d926f3e010c29e804bf940d618c663" -dependencies = [ - "statig_macro", -] - -[[package]] -name = "statig_macro" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4c61563b68df6e452ceece3fba1329c8c6a5d348fe17b0778fada28bc95fde" -dependencies = [ - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "std_prelude" version = "0.2.12" @@ -14407,7 +14427,6 @@ dependencies = [ "db-core", "db-user", "dirs 6.0.0", - "flume", "futures-util", "hound", "insta", @@ -14417,12 +14436,13 @@ dependencies = [ "ordered-float 5.0.0", "owhisper-client", "owhisper-interface", + "ractor", + "ractor_actors", "rodio", "serde", "serde_json", "specta", "specta-typescript", - "statig", "strum 0.26.3", "tauri", "tauri-plugin", @@ -14438,6 +14458,7 @@ dependencies = [ "thiserror 2.0.16", "tokio", "tokio-stream", + "tokio-util", "tracing", "url", "uuid", @@ -15619,6 +15640,30 @@ dependencies = [ "webpki-roots 0.26.11", ] +[[package]] +name = "tokio_with_wasm" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dfba9b946459940fb564dcf576631074cdfb0bfe4c962acd4c31f0dca7897e6" +dependencies = [ + "js-sys", + "tokio", + "tokio_with_wasm_proc", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "tokio_with_wasm_proc" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37e04c1865c281139e5ccf633cb9f76ffdaabeebfe53b703984cf82878e2aabb" +dependencies = [ + "quote", + "syn 2.0.106", +] + [[package]] name = "toml" version = "0.8.23" diff --git a/crates/aec/Cargo.toml b/crates/aec/Cargo.toml index aa954b4c4f..7247a894e4 100644 --- a/crates/aec/Cargo.toml +++ b/crates/aec/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [features] -default = ["512"] +default = ["128"] 128 = [] 256 = [] 512 = [] diff --git a/crates/audio/src/speaker/macos.rs b/crates/audio/src/speaker/macos.rs index a8ef68a075..7c471c06f8 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -6,15 +6,13 @@ use anyhow::Result; use futures_util::Stream; use ringbuf::{ - traits::{Consumer, Observer, Producer, Split}, + traits::{Consumer, Producer, Split}, HeapCons, HeapProd, HeapRb, }; use ca::aggregate_device_keys as agg_keys; use cidre::{arc, av, cat, cf, core_audio as ca, ns, os}; -// https://github.com/yury/cidre/blob/7bc6c3a/cidre/examples/core-audio-record/main.rs -// https://github.com/floneum/floneum/blob/50afe10/interfaces/kalosm-sound/src/source/mic.rs#L41 pub struct SpeakerInput { tap: ca::TapGuard, agg_desc: arc::Retained>, @@ -36,7 +34,7 @@ pub struct SpeakerStream { impl SpeakerStream { pub fn sample_rate(&self) -> u32 { - self.current_sample_rate.load(Ordering::Relaxed) + self.current_sample_rate.load(Ordering::Acquire) } } @@ -119,25 +117,21 @@ impl SpeakerInput { device .actual_sample_rate() .unwrap_or(ctx.format.absd().sample_rate) as u32, - Ordering::Relaxed, + Ordering::Release, ); - assert_eq!(ctx.format.common_format(), av::audio::CommonFormat::PcmF32); - if let Some(view) = av::AudioPcmBuf::with_buf_list_no_copy(&ctx.format, input_data, None) { if let Some(data) = view.data_f32_at(0) { process_audio_data(ctx, data); - } else { - tracing::warn!("macos_speaker_view_no_channel_0"); } - } else { + } else if ctx.format.common_format() == av::audio::CommonFormat::PcmF32 { let first_buffer = &input_data.buffers[0]; let byte_count = first_buffer.data_bytes_size as usize; let float_count = byte_count / std::mem::size_of::(); - if float_count > 0 { + if float_count > 0 && first_buffer.data != std::ptr::null_mut() { let data = unsafe { std::slice::from_raw_parts(first_buffer.data as *const f32, float_count) }; @@ -157,9 +151,11 @@ impl SpeakerInput { pub fn stream(self) -> SpeakerStream { let asbd = self.tap.asbd().unwrap(); + let format = av::AudioFormat::with_asbd(&asbd).unwrap(); - let rb = HeapRb::::new(1024 * 32); + let buffer_size = 1024 * 128; + let rb = HeapRb::::new(buffer_size); let (producer, consumer) = rb.split(); let waker_state = Arc::new(Mutex::new(WakerState { @@ -193,41 +189,32 @@ impl SpeakerInput { fn process_audio_data(ctx: &mut Ctx, data: &[f32]) { let buffer_size = data.len(); - let available_space = ctx.producer.vacant_len(); - - let buffer_fill_ratio = 1.0 - (available_space as f32 / ctx.producer.capacity().get() as f32); - if buffer_fill_ratio > 0.7 { - tracing::warn!(ratio = buffer_fill_ratio, "buffer_nearly_full",); - } - let pushed = ctx.producer.push_slice(data); if pushed < buffer_size { - let dropped = buffer_size - pushed; - let consecutive = ctx.consecutive_drops.fetch_add(1, Ordering::Relaxed) + 1; - - tracing::warn!( - dropped = dropped, - consecutive = consecutive, - "macos_speaker_dropped", - ); + let consecutive = ctx.consecutive_drops.fetch_add(1, Ordering::AcqRel) + 1; if consecutive > 10 { - ctx.should_terminate.store(true, Ordering::Relaxed); + ctx.should_terminate.store(true, Ordering::Release); return; } } else { - ctx.consecutive_drops.store(0, Ordering::Relaxed); + ctx.consecutive_drops.store(0, Ordering::Release); } if pushed > 0 { - let mut waker_state = ctx.waker_state.lock().unwrap(); - if !waker_state.has_data { - waker_state.has_data = true; - if let Some(waker) = waker_state.waker.take() { - drop(waker_state); - waker.wake(); + let should_wake = { + let mut waker_state = ctx.waker_state.lock().unwrap(); + if !waker_state.has_data { + waker_state.has_data = true; + waker_state.waker.take() + } else { + None } + }; + + if let Some(waker) = should_wake { + waker.wake(); } } } @@ -239,24 +226,29 @@ impl Stream for SpeakerStream { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - if self._ctx.should_terminate.load(Ordering::Relaxed) { - return Poll::Ready(None); - } - if let Some(sample) = self.consumer.try_pop() { return Poll::Ready(Some(sample)); } + if self._ctx.should_terminate.load(Ordering::Acquire) { + return match self.consumer.try_pop() { + Some(sample) => Poll::Ready(Some(sample)), + None => Poll::Ready(None), + }; + } + { let mut state = self.waker_state.lock().unwrap(); state.has_data = false; state.waker = Some(cx.waker().clone()); - drop(state); } - match self.consumer.try_pop() { - Some(sample) => Poll::Ready(Some(sample)), - None => Poll::Pending, - } + Poll::Pending + } +} + +impl Drop for SpeakerStream { + fn drop(&mut self) { + self._ctx.should_terminate.store(true, Ordering::Release); } } diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index ad7dd8674a..463bde6a67 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -60,12 +60,13 @@ uuid = { workspace = true, features = ["v4"] } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tokio-stream = { workspace = true } +tokio-util = { workspace = true } tracing = { workspace = true } hound = { workspace = true } -flume = { workspace = true } -statig = { workspace = true, features = ["async"] } +ractor = "0.15" +ractor_actors = "0.4" [target."cfg(target_os = \"macos\")".dependencies] objc2 = { workspace = true } diff --git a/plugins/listener/src/actors/listen.rs b/plugins/listener/src/actors/listen.rs new file mode 100644 index 0000000000..717c5a40f5 --- /dev/null +++ b/plugins/listener/src/actors/listen.rs @@ -0,0 +1,203 @@ +use std::collections::HashMap; +use std::time::Duration; + +use bytes::Bytes; +use futures_util::StreamExt; + +use owhisper_interface::{ControlMessage, MixedMessage, Word2}; +use ractor::{Actor, ActorProcessingErr, ActorRef}; +use tauri_specta::Event; + +use crate::{manager::TranscriptManager, SessionEvent}; + +const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(60 * 15); + +pub enum ListenMsg { + Audio(Bytes, Bytes), +} + +pub struct ListenArgs { + pub app: tauri::AppHandle, + pub session_id: String, + pub languages: Vec, + pub onboarding: bool, + pub session_start_ts_ms: u64, +} + +pub struct ListenState { + tx: tokio::sync::mpsc::Sender>, + rx_task: tokio::task::JoinHandle<()>, +} + +pub struct ListenBridge; +impl Actor for ListenBridge { + type Msg = ListenMsg; + type State = ListenState; + type Arguments = ListenArgs; + + async fn pre_start( + &self, + myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let (tx, rx) = + tokio::sync::mpsc::channel::>(32); + + let conn = { + use tauri_plugin_local_stt::LocalSttPluginExt; + + match args.app.get_connection().await { + Ok(c) => c, + Err(e) => { + tracing::error!("failed_to_get_connection: {:?}", e); + return Err(ActorProcessingErr::from(e)); + } + } + }; + + let client = owhisper_client::ListenClient::builder() + .api_base(conn.base_url) + .api_key(conn.api_key.unwrap_or_default()) + .params(owhisper_interface::ListenParams { + model: conn.model, + languages: args.languages, + redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), + ..Default::default() + }) + .build_dual(); + + let rx_task = tokio::spawn({ + let app = args.app.clone(); + let session_id = args.session_id.clone(); + + async move { + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + let (listen_stream, _handle) = match client.from_realtime_audio(outbound).await { + Ok(res) => res, + Err(e) => { + tracing::error!("listen_ws_connect_failed: {:?}", e); + myself.stop(Some(format!("listen_ws_connect_failed: {:?}", e))); + return; + } + }; + futures_util::pin_mut!(listen_stream); + + let mut manager = TranscriptManager::with_unix_timestamp(args.session_start_ts_ms); + + loop { + match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { + Ok(Some(response)) => { + let diff = manager.append(response.clone()); + + let partial_words_by_channel: HashMap> = diff + .partial_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + SessionEvent::PartialWords { + words: partial_words_by_channel, + } + .emit(&app) + .unwrap(); + + let final_words_by_channel: HashMap> = diff + .final_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + update_session( + &app, + &session_id, + final_words_by_channel + .clone() + .values() + .flatten() + .cloned() + .collect(), + ) + .await + .unwrap(); + + SessionEvent::FinalWords { + words: final_words_by_channel, + } + .emit(&app) + .unwrap(); + } + Ok(None) => { + tracing::info!("listen_stream_ended"); + break; + } + Err(_) => { + tracing::info!("listen_stream_timeout"); + break; + } + } + } + + myself.stop(None); + } + }); + + Ok(ListenState { tx, rx_task }) + } + + async fn handle( + &self, + _myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + ListenMsg::Audio(mic, spk) => { + let _ = state.tx.try_send(MixedMessage::Audio((mic, spk))); + } + } + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + state.rx_task.abort(); + Ok(()) + } +} + +async fn update_session( + app: &tauri::AppHandle, + session_id: impl Into, + words: Vec, +) -> Result, crate::Error> { + use tauri_plugin_db::DatabasePluginExt; + + let mut session = app + .db_get_session(session_id) + .await? + .ok_or(crate::Error::NoneSession)?; + + session.words.extend(words); + app.db_upsert_session(session.clone()).await.unwrap(); + + Ok(session.words) +} diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs new file mode 100644 index 0000000000..6713b787c3 --- /dev/null +++ b/plugins/listener/src/actors/mod.rs @@ -0,0 +1,16 @@ +mod listen; +mod processor; +mod recorder; +mod session; +mod source; + +pub use listen::*; +pub use processor::*; +pub use recorder::*; +pub use session::*; +pub use source::*; + +#[derive(Clone)] +pub struct AudioChunk { + data: Vec, +} diff --git a/plugins/listener/src/actors/processor.rs b/plugins/listener/src/actors/processor.rs new file mode 100644 index 0000000000..cbadb5a94c --- /dev/null +++ b/plugins/listener/src/actors/processor.rs @@ -0,0 +1,169 @@ +use std::{ + collections::VecDeque, + sync::Arc, + time::{Duration, Instant}, +}; + +use ractor::{Actor, ActorProcessingErr, ActorRef}; +use tauri_specta::Event; + +use crate::{ + actors::{AudioChunk, ListenMsg, RecMsg}, + SessionEvent, +}; + +const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); + +pub enum ProcMsg { + Mic(AudioChunk), + Spk(AudioChunk), + AttachRecorder(ActorRef), + AttachListen(ActorRef), +} + +pub struct ProcArgs { + pub app: tauri::AppHandle, + pub mixed_to: Option>, + pub rec_to: Option>, + pub listen_tx: Option>, +} + +pub struct ProcState { + app: tauri::AppHandle, + joiner: Joiner, + aec: hypr_aec::AEC, + agc_m: hypr_agc::Agc, + agc_s: hypr_agc::Agc, + last_amp: Instant, + recorder: Option>, + listen: Option>, + last_mic: Option>, + last_spk: Option>, +} + +pub struct AudioProcessor {} +impl Actor for AudioProcessor { + type Msg = ProcMsg; + type State = ProcState; + type Arguments = ProcArgs; + + async fn pre_start( + &self, + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + Ok(ProcState { + app: args.app.clone(), + joiner: Joiner::new(), + aec: hypr_aec::AEC::new().unwrap(), + agc_m: hypr_agc::Agc::default(), + agc_s: hypr_agc::Agc::default(), + last_amp: Instant::now(), + recorder: args.mixed_to.or(args.rec_to), + listen: args.listen_tx, + last_mic: None, + last_spk: None, + }) + } + + async fn handle( + &self, + _myself: ActorRef, + msg: Self::Msg, + st: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match msg { + ProcMsg::AttachRecorder(r) => st.recorder = Some(r), + ProcMsg::AttachListen(l) => st.listen = Some(l), + ProcMsg::Mic(mut c) => { + st.agc_m.process(&mut c.data); + let arc = Arc::<[f32]>::from(c.data); + st.last_mic = Some(arc.clone()); + st.joiner.push_mic(arc); + process_ready(st).await; + } + ProcMsg::Spk(mut c) => { + st.agc_s.process(&mut c.data); + let arc = Arc::<[f32]>::from(c.data); + st.last_spk = Some(arc.clone()); + st.joiner.push_spk(arc); + process_ready(st).await; + } + } + Ok(()) + } +} + +async fn process_ready(st: &mut ProcState) { + while let Some((mic, spk)) = st.joiner.pop_pair() { + let mic_out = st + .aec + .process_streaming(&mic, &spk) + .unwrap_or_else(|_| mic.to_vec()); + + if let Some(rec) = &st.recorder { + let mixed: Vec = mic_out + .iter() + .zip(spk.iter()) + .map(|(m, s)| (m + s).clamp(-1.0, 1.0)) + .collect(); + rec.cast(RecMsg::Audio(mixed)).ok(); + } + + if let Some(list) = &st.listen { + let mic_bytes = hypr_audio_utils::f32_to_i16_bytes(mic_out.into_iter()); + let spk_bytes = hypr_audio_utils::f32_to_i16_bytes(spk.iter().copied()); + list.cast(ListenMsg::Audio(mic_bytes.into(), spk_bytes.into())) + .ok(); + } + } + + if st.last_amp.elapsed() >= AUDIO_AMPLITUDE_THROTTLE { + if let (Some(mic_data), Some(spk_data)) = (&st.last_mic, &st.last_spk) { + if let Err(e) = SessionEvent::from((mic_data.as_ref(), spk_data.as_ref())).emit(&st.app) + { + tracing::error!("Failed to emit AudioAmplitude event: {:?}", e); + } + + st.last_amp = Instant::now(); + } + } +} + +struct Joiner { + mic: VecDeque>, + spk: VecDeque>, +} + +impl Joiner { + fn new() -> Self { + Self { + mic: VecDeque::new(), + spk: VecDeque::new(), + } + } + + fn push_mic(&mut self, data: Arc<[f32]>) { + self.mic.push_back(data); + if self.mic.len() > 10 { + self.mic.pop_front(); + } + } + + fn push_spk(&mut self, data: Arc<[f32]>) { + self.spk.push_back(data); + if self.spk.len() > 10 { + self.spk.pop_front(); + } + } + + fn pop_pair(&mut self) -> Option<(Arc<[f32]>, Arc<[f32]>)> { + if !self.mic.is_empty() && !self.spk.is_empty() { + let mic = self.mic.pop_front()?; + let spk = self.spk.pop_front()?; + Some((mic, spk)) + } else { + None + } + } +} diff --git a/plugins/listener/src/actors/recorder.rs b/plugins/listener/src/actors/recorder.rs new file mode 100644 index 0000000000..686378a1f5 --- /dev/null +++ b/plugins/listener/src/actors/recorder.rs @@ -0,0 +1,78 @@ +use std::path::PathBuf; + +use ractor::{Actor, ActorProcessingErr, ActorRef}; + +pub enum RecMsg { + Audio(Vec), +} + +pub struct RecArgs { + pub app_dir: PathBuf, + pub session_id: String, +} + +pub struct RecState { + writer: Option>>, +} + +pub struct Recorder; +impl Actor for Recorder { + type Msg = RecMsg; + type State = RecState; + type Arguments = RecArgs; + + async fn pre_start( + &self, + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let dir = args.app_dir.join(&args.session_id); + std::fs::create_dir_all(&dir)?; + let path = dir.join("audio.wav"); + let spec = hound::WavSpec { + channels: 1, + sample_rate: 16000, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }; + let writer = if path.exists() { + hound::WavWriter::append(path)? + } else { + hound::WavWriter::create(path, spec)? + }; + Ok(RecState { + writer: Some(writer), + }) + } + + async fn handle( + &self, + _myself: ActorRef, + msg: Self::Msg, + st: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match msg { + RecMsg::Audio(v) => { + if let Some(ref mut writer) = st.writer { + for s in v { + writer.write_sample(s)?; + } + } + } + } + + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + st: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + if let Some(writer) = st.writer.take() { + writer.finalize()?; + } + + Ok(()) + } +} diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs new file mode 100644 index 0000000000..932ab371e1 --- /dev/null +++ b/plugins/listener/src/actors/session.rs @@ -0,0 +1,384 @@ +use ractor::{ + call_t, Actor, ActorCell, ActorProcessingErr, ActorRef, RpcReplyPort, SupervisionEvent, +}; +use tauri::Manager; +use tauri_specta::Event; +use tokio_util::sync::CancellationToken; + +use crate::{ + actors::{ + AudioProcessor, ListenArgs, ListenBridge, ListenMsg, ProcArgs, ProcMsg, RecArgs, RecMsg, + Recorder, SourceActor, SrcArgs, SrcCtrl, SrcWhich, + }, + fsm::State, + SessionEvent, +}; + +#[derive(Debug)] +pub enum SessionMsg { + Start { session_id: String }, + Stop, + SetMicMute(bool), + SetSpeakerMute(bool), + GetMicMute(RpcReplyPort), + GetSpeakerMute(RpcReplyPort), + GetMicDeviceName(RpcReplyPort>), + ChangeMicDevice(Option), + GetState(RpcReplyPort), +} + +pub struct SessionArgs { + pub app: tauri::AppHandle, +} + +pub struct SessionState { + app: tauri::AppHandle, + state: State, + session_id: Option, + session_start_ts_ms: Option, + + mic_source: Option>, + speaker_source: Option>, + processor: Option>, + recorder: Option>, + listen: Option>, + + record_enabled: bool, + languages: Vec, + onboarding: bool, + + token: CancellationToken, +} + +pub struct SessionSupervisor; + +impl Actor for SessionSupervisor { + type Msg = SessionMsg; + type State = SessionState; + type Arguments = SessionArgs; + + async fn pre_start( + &self, + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + Ok(SessionState { + app: args.app, + state: State::Inactive, + session_id: None, + session_start_ts_ms: None, + mic_source: None, + speaker_source: None, + processor: None, + recorder: None, + listen: None, + record_enabled: true, + languages: vec![], + onboarding: false, + token: CancellationToken::new(), + }) + } + + async fn handle( + &self, + myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + SessionMsg::Start { session_id } => { + if let State::RunningActive = state.state { + if let Some(current_id) = &state.session_id { + if current_id != &session_id { + self.stop_session(state).await?; + } else { + return Ok(()); + } + } + } + + self.start_session(myself.get_cell(), state, session_id) + .await?; + } + + SessionMsg::Stop => { + self.stop_session(state).await?; + } + + SessionMsg::SetMicMute(muted) => { + if let Some(mic) = &state.mic_source { + mic.cast(SrcCtrl::SetMute(muted))?; + } + SessionEvent::MicMuted { value: muted }.emit(&state.app)?; + } + + SessionMsg::SetSpeakerMute(muted) => { + if let Some(spk) = &state.speaker_source { + spk.cast(SrcCtrl::SetMute(muted))?; + } + SessionEvent::SpeakerMuted { value: muted }.emit(&state.app)?; + } + + SessionMsg::GetMicDeviceName(reply) => { + if !reply.is_closed() { + let device_name = if let Some(mic) = &state.mic_source { + call_t!(mic, SrcCtrl::GetDevice, 100).unwrap_or(None) + } else { + None + }; + + let _ = reply.send(device_name); + } + } + + SessionMsg::GetMicMute(reply) => { + let muted = if let Some(mic) = &state.mic_source { + call_t!(mic, SrcCtrl::GetMute, 100)? + } else { + false + }; + + if !reply.is_closed() { + let _ = reply.send(muted); + } + } + + SessionMsg::GetSpeakerMute(reply) => { + let muted = if let Some(spk) = &state.speaker_source { + call_t!(spk, SrcCtrl::GetMute, 100)? + } else { + false + }; + + if !reply.is_closed() { + let _ = reply.send(muted); + } + } + + SessionMsg::ChangeMicDevice(device) => { + if let Some(mic) = &state.mic_source { + mic.cast(SrcCtrl::SetDevice(device))?; + } + } + + SessionMsg::GetState(reply) => { + if !reply.is_closed() { + let _ = reply.send(state.state.clone()); + } + } + } + + Ok(()) + } + + async fn handle_supervisor_evt( + &self, + _myself: ActorRef, + event: SupervisionEvent, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match event { + SupervisionEvent::ActorStarted(actor) => { + tracing::info!("{:?}_actor_started", actor.get_name()); + } + + SupervisionEvent::ActorFailed(actor, _) => { + tracing::error!("{:?}_actor_failed", actor.get_name()); + self.stop_session(state).await?; + } + + SupervisionEvent::ActorTerminated(actor, _, exit_reason) => { + tracing::info!("{:?}_actor_terminated: {:?}", actor.get_name(), exit_reason); + + if matches!(state.state, State::RunningActive) { + self.stop_session(state).await?; + } + } + + _ => {} + } + + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + self.stop_session(state).await?; + Ok(()) + } +} + +impl SessionSupervisor { + async fn start_session( + &self, + supervisor: ActorCell, + state: &mut SessionState, + session_id: String, + ) -> Result<(), ActorProcessingErr> { + use tauri_plugin_db::{DatabasePluginExt, UserDatabase}; + + let user_id = state.app.db_user_id().await?.unwrap(); + let onboarding_session_id = UserDatabase::onboarding_session_id(); + state.onboarding = session_id == onboarding_session_id; + + let config = state.app.db_get_config(&user_id).await?; + state.record_enabled = config + .as_ref() + .is_none_or(|c| c.general.save_recordings.unwrap_or(true)); + state.languages = config.as_ref().map_or_else( + || vec![hypr_language::ISO639::En.into()], + |c| c.general.spoken_languages.clone(), + ); + + state.session_id = Some(session_id.clone()); + state.session_start_ts_ms = Some( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64, + ); + + if let Ok(Some(mut session)) = state.app.db_get_session(&session_id).await { + session.record_start = Some(chrono::Utc::now()); + let _ = state.app.db_upsert_session(session).await; + } + + state.token = CancellationToken::new(); + + let (processor_ref, _) = Actor::spawn_linked( + Some("audio_processor".to_string()), + AudioProcessor {}, + ProcArgs { + app: state.app.clone(), + mixed_to: None, + rec_to: None, + listen_tx: None, + }, + supervisor.clone(), + ) + .await?; + state.processor = Some(processor_ref.clone()); + + let (mic_ref, _) = Actor::spawn_linked( + Some("mic_source".to_string()), + SourceActor, + SrcArgs { + which: SrcWhich::Mic { device: None }, + proc: processor_ref.clone(), + token: state.token.clone(), + }, + supervisor.clone(), + ) + .await?; + state.mic_source = Some(mic_ref.clone()); + + let (spk_ref, _) = Actor::spawn_linked( + Some("speaker_source".to_string()), + SourceActor, + SrcArgs { + which: SrcWhich::Speaker, + proc: processor_ref.clone(), + token: state.token.clone(), + }, + supervisor.clone(), + ) + .await?; + state.speaker_source = Some(spk_ref); + + if state.record_enabled { + let app_dir = state.app.path().app_data_dir().unwrap(); + let (rec_ref, _) = Actor::spawn_linked( + Some("recorder".to_string()), + Recorder, + RecArgs { + app_dir, + session_id: session_id.clone(), + }, + supervisor.clone(), + ) + .await?; + state.recorder = Some(rec_ref.clone()); + processor_ref.cast(ProcMsg::AttachRecorder(rec_ref))?; + } + + let (listen_ref, _) = Actor::spawn_linked( + Some("listen_bridge".to_string()), + ListenBridge, + ListenArgs { + app: state.app.clone(), + session_id: session_id.clone(), + languages: state.languages.clone(), + onboarding: state.onboarding, + session_start_ts_ms: state.session_start_ts_ms.unwrap_or(0), + }, + supervisor, + ) + .await?; + state.listen = Some(listen_ref.clone()); + processor_ref.cast(ProcMsg::AttachListen(listen_ref))?; + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = state.app.set_start_disabled(true); + } + + state.state = State::RunningActive; + SessionEvent::RunningActive {}.emit(&state.app)?; + + Ok(()) + } + + async fn stop_session(&self, state: &mut SessionState) -> Result<(), ActorProcessingErr> { + if matches!(state.state, State::Inactive) { + return Ok(()); + } + + state.token.cancel(); + + if let Some(mic) = state.mic_source.take() { + mic.stop(None); + } + if let Some(spk) = state.speaker_source.take() { + spk.stop(None); + } + if let Some(proc) = state.processor.take() { + proc.stop(None); + } + if let Some(rec) = state.recorder.take() { + rec.stop(None); + } + if let Some(listen) = state.listen.take() { + listen.stop(None); + } + + if let Some(session_id) = &state.session_id { + use tauri_plugin_db::DatabasePluginExt; + + if let Ok(Some(mut session)) = state.app.db_get_session(session_id).await { + session.record_end = Some(chrono::Utc::now()); + let _ = state.app.db_upsert_session(session).await; + } + } + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = state.app.set_start_disabled(false); + } + + { + use tauri_plugin_windows::{HyprWindow, WindowsPluginExt}; + let _ = state.app.window_hide(HyprWindow::Control); + } + + state.session_id = None; + state.session_start_ts_ms = None; + state.state = State::Inactive; + + SessionEvent::Inactive {}.emit(&state.app)?; + + Ok(()) + } +} diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs new file mode 100644 index 0000000000..14dfdac8fe --- /dev/null +++ b/plugins/listener/src/actors/source.rs @@ -0,0 +1,202 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use futures_util::StreamExt; +use ractor::{Actor, ActorProcessingErr, ActorRef, RpcReplyPort}; +use tokio_util::sync::CancellationToken; + +use crate::actors::{AudioChunk, ProcMsg}; +use hypr_audio::{ + AudioInput, DeviceEvent, DeviceMonitor, DeviceMonitorHandle, ResampledAsyncSource, +}; + +const SAMPLE_RATE: u32 = 16000; + +pub enum SrcCtrl { + SetMute(bool), + GetMute(RpcReplyPort), + SetDevice(Option), + GetDevice(RpcReplyPort>), +} + +#[derive(Clone)] +pub enum SrcWhich { + Mic { device: Option }, + Speaker, +} + +pub struct SrcArgs { + pub which: SrcWhich, + pub proc: ActorRef, + pub token: CancellationToken, +} + +pub struct SrcState { + which: SrcWhich, + proc: ActorRef, + token: CancellationToken, + muted: Arc, + run_task: Option>, + _device_monitor_handle: Option, + _silence_stream_tx: Option>, +} + +pub struct SourceActor; +impl Actor for SourceActor { + type Msg = SrcCtrl; + type State = SrcState; + type Arguments = SrcArgs; + + async fn pre_start( + &self, + myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let device_monitor_handle = if matches!(args.which, SrcWhich::Mic { .. }) { + let (event_tx, event_rx) = std::sync::mpsc::channel(); + let device_monitor_handle = DeviceMonitor::spawn(event_tx); + + let myself_clone = myself.clone(); + std::thread::spawn(move || { + while let Ok(event) = event_rx.recv() { + if let DeviceEvent::DefaultInputChanged { .. } = event { + let new_device = AudioInput::get_default_mic_device_name(); + let _ = myself_clone.cast(SrcCtrl::SetDevice(Some(new_device))); + } + } + }); + + Some(device_monitor_handle) + } else { + None + }; + + let silence_stream_tx = if matches!(args.which, SrcWhich::Speaker) { + Some(hypr_audio::AudioOutput::silence()) + } else { + None + }; + + let mut st = SrcState { + which: args.which, + proc: args.proc, + token: args.token, + muted: Arc::new(AtomicBool::new(false)), + run_task: None, + _device_monitor_handle: device_monitor_handle, + _silence_stream_tx: silence_stream_tx, + }; + + start_source_loop(&myself, &mut st).await?; + Ok(st) + } + + async fn handle( + &self, + myself: ActorRef, + msg: Self::Msg, + st: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match (msg, &mut st.which) { + (SrcCtrl::SetMute(muted), _) => { + st.muted.store(muted, Ordering::Relaxed); + } + (SrcCtrl::GetMute(reply), _) => { + if !reply.is_closed() { + let _ = reply.send(st.muted.load(Ordering::Relaxed)); + } + } + (SrcCtrl::GetDevice(reply), _) => { + if !reply.is_closed() { + let device = match &st.which { + SrcWhich::Mic { device } => device.clone(), + SrcWhich::Speaker => None, + }; + let _ = reply.send(device); + } + } + (SrcCtrl::SetDevice(dev), SrcWhich::Mic { device }) => { + *device = dev; + if let Some(t) = st.run_task.take() { + t.abort(); + } + start_source_loop(&myself, st).await?; + } + _ => {} + } + + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + st: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + if let Some(task) = st.run_task.take() { + task.abort(); + } + + st._silence_stream_tx = None; + + Ok(()) + } +} + +async fn start_source_loop( + myself: &ActorRef, + st: &mut SrcState, +) -> Result<(), ActorProcessingErr> { + let myself2 = myself.clone(); + + let proc = st.proc.clone(); + let token = st.token.clone(); + let which = st.which.clone(); + let muted = st.muted.clone(); + + let handle = tokio::spawn(async move { + loop { + let stream = match &which { + SrcWhich::Mic { device } => { + let mut input = hypr_audio::AudioInput::from_mic(device.clone()).unwrap(); + + ResampledAsyncSource::new(input.stream(), SAMPLE_RATE) + .chunks(hypr_aec::BLOCK_SIZE) + } + SrcWhich::Speaker => { + let input = hypr_audio::AudioInput::from_speaker().stream(); + ResampledAsyncSource::new(input, SAMPLE_RATE).chunks(hypr_aec::BLOCK_SIZE) + } + }; + tokio::pin!(stream); + + loop { + tokio::select! { + _ = token.cancelled() => { myself2.stop(None); return (); } + next = stream.next() => { + if let Some(data) = next { + let output_data = if muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + data + }; + + let msg = match &which { + SrcWhich::Mic {..} => ProcMsg::Mic(AudioChunk{ data: output_data }), + SrcWhich::Speaker => ProcMsg::Spk(AudioChunk{ data: output_data }), + }; + let _ = proc.cast(msg); + } else { + break; + } + } + } + } + tokio::time::sleep(Duration::from_millis(200)).await; + } + }); + + st.run_task = Some(handle); + Ok(()) +} diff --git a/plugins/listener/src/commands.rs b/plugins/listener/src/commands.rs index 99435c4ad0..fbe6e267d5 100644 --- a/plugins/listener/src/commands.rs +++ b/plugins/listener/src/commands.rs @@ -132,20 +132,14 @@ pub async fn start_session( session_id: String, ) -> Result<(), String> { app.start_session(session_id).await; - match app.get_state().await { - crate::fsm::State::RunningActive { .. } => Ok(()), - _ => Err(crate::Error::StartSessionFailed.to_string()), - } + Ok(()) } #[tauri::command] #[specta::specta] pub async fn stop_session(app: tauri::AppHandle) -> Result<(), String> { app.stop_session().await; - match app.get_state().await { - crate::fsm::State::Inactive { .. } => Ok(()), - _ => Err(crate::Error::StopSessionFailed.to_string()), - } + Ok(()) } #[tauri::command] diff --git a/plugins/listener/src/ext.rs b/plugins/listener/src/ext.rs index 6cfc3f74d7..323546bb50 100644 --- a/plugins/listener/src/ext.rs +++ b/plugins/listener/src/ext.rs @@ -1,6 +1,7 @@ use std::future::Future; use futures_util::StreamExt; +use ractor::call_t; #[cfg(target_os = "macos")] use { @@ -8,6 +9,8 @@ use { objc2_foundation::NSString, }; +use crate::actors::SessionMsg; + pub trait ListenerPluginExt { fn list_microphone_devices(&self) -> impl Future, crate::Error>>; fn get_current_microphone_device( @@ -44,8 +47,16 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn get_current_microphone_device(&self) -> Result, crate::Error> { let state = self.state::(); - let s = state.lock().await; - Ok(s.fsm.get_current_mic_device()) + let guard = state.lock().await; + + if let Some(supervisor) = &guard.supervisor { + match call_t!(supervisor, SessionMsg::GetMicDeviceName, 100) { + Ok(device_name) => Ok(device_name), + Err(_) => Ok(None), + } + } else { + Ok(None) + } } #[tracing::instrument(skip_all)] @@ -54,11 +65,10 @@ impl> ListenerPluginExt for T { device_name: impl Into, ) -> Result<(), crate::Error> { let state = self.state::(); + let guard = state.lock().await; - { - let mut guard = state.lock().await; - let event = crate::fsm::StateEvent::MicChange(Some(device_name.into())); - guard.fsm.handle(&event).await; + if let Some(supervisor) = &guard.supervisor { + let _ = supervisor.cast(SessionMsg::ChangeMicDevice(Some(device_name.into()))); } Ok(()) @@ -178,70 +188,78 @@ impl> ListenerPluginExt for T { async fn get_state(&self) -> crate::fsm::State { let state = self.state::(); let guard = state.lock().await; - guard.fsm.state().clone() + guard.get_state().await } #[tracing::instrument(skip_all)] async fn get_mic_muted(&self) -> bool { let state = self.state::(); + let guard = state.lock().await; - { - let guard = state.lock().await; - guard.fsm.is_mic_muted() + if let Some(supervisor) = &guard.supervisor { + match call_t!(supervisor, SessionMsg::GetMicMute, 100) { + Ok(muted) => muted, + Err(_) => false, + } + } else { + false } } #[tracing::instrument(skip_all)] async fn get_speaker_muted(&self) -> bool { let state = self.state::(); + let guard = state.lock().await; - { - let guard = state.lock().await; - guard.fsm.is_speaker_muted() + if let Some(supervisor) = &guard.supervisor { + match call_t!(supervisor, SessionMsg::GetSpeakerMute, 100) { + Ok(muted) => muted, + Err(_) => false, + } + } else { + false } } #[tracing::instrument(skip_all)] async fn set_mic_muted(&self, muted: bool) { let state = self.state::(); + let guard = state.lock().await; - { - let mut guard = state.lock().await; - let event = crate::fsm::StateEvent::MicMuted(muted); - guard.fsm.handle(&event).await; + if let Some(supervisor) = &guard.supervisor { + let _ = supervisor.cast(SessionMsg::SetMicMute(muted)); } } #[tracing::instrument(skip_all)] async fn set_speaker_muted(&self, muted: bool) { let state = self.state::(); + let guard = state.lock().await; - { - let mut guard = state.lock().await; - let event = crate::fsm::StateEvent::SpeakerMuted(muted); - guard.fsm.handle(&event).await; + if let Some(supervisor) = &guard.supervisor { + let _ = supervisor.cast(SessionMsg::SetSpeakerMute(muted)); } } #[tracing::instrument(skip_all)] async fn start_session(&self, session_id: impl Into) { let state = self.state::(); + let guard = state.lock().await; - { - let mut guard = state.lock().await; - let event = crate::fsm::StateEvent::Start(session_id.into()); - guard.fsm.handle(&event).await; + if let Some(supervisor) = &guard.supervisor { + let _ = supervisor.cast(SessionMsg::Start { + session_id: session_id.into(), + }); } } #[tracing::instrument(skip_all)] async fn stop_session(&self) { let state = self.state::(); + let guard = state.lock().await; - { - let mut guard = state.lock().await; - let event = crate::fsm::StateEvent::Stop; - guard.fsm.handle(&event).await; + if let Some(supervisor) = &guard.supervisor { + let _ = supervisor.cast(SessionMsg::Stop); } } } diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index 9998773010..9380383c01 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -1,789 +1,7 @@ -use statig::prelude::*; - -use std::collections::HashMap; -use std::time::{Duration, Instant}; - -use tauri::Manager; -use tauri_specta::Event; - -use futures_util::StreamExt; -use tokio::task::JoinSet; - -use hypr_audio::ResampledAsyncSource; - -use crate::{manager::TranscriptManager, SessionEvent}; - -const SAMPLE_RATE: u32 = 16000; -const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); -const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(60 * 15); - -const WAV_SPEC: hound::WavSpec = hound::WavSpec { - channels: 1, - sample_rate: SAMPLE_RATE, - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, -}; - -struct AudioSaver; - -impl AudioSaver { - async fn save_to_wav( - rx: flume::Receiver>, - session_id: &str, - app_dir: &std::path::Path, - filename: &str, - append: bool, - ) -> Result<(), Box> { - let dir = app_dir.join(session_id); - std::fs::create_dir_all(&dir)?; - let path = dir.join(filename); - - let mut wav = if append && path.exists() { - hound::WavWriter::append(path)? - } else { - hound::WavWriter::create(path, WAV_SPEC)? - }; - - while let Ok(chunk) = rx.recv_async().await { - for sample in chunk { - wav.write_sample(sample)?; - } - } - - wav.finalize()?; - Ok(()) - } -} - -struct AudioChannels { - mic_tx: flume::Sender>, - mic_rx: flume::Receiver>, - speaker_tx: flume::Sender>, - speaker_rx: flume::Receiver>, - save_mixed_tx: flume::Sender>, - save_mixed_rx: flume::Receiver>, - save_mic_raw_tx: Option>>, - save_mic_raw_rx: Option>>, - save_speaker_raw_tx: Option>>, - save_speaker_raw_rx: Option>>, - process_mic_tx: flume::Sender>, - process_mic_rx: flume::Receiver>, - process_speaker_tx: flume::Sender>, - process_speaker_rx: flume::Receiver>, -} - -impl AudioChannels { - fn new() -> Self { - const CHUNK_BUFFER_SIZE: usize = 64; - - let (mic_tx, mic_rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - let (speaker_tx, speaker_rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - let (save_mixed_tx, save_mixed_rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - let (process_mic_tx, process_mic_rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - let (process_speaker_tx, process_speaker_rx) = - flume::bounded::>(CHUNK_BUFFER_SIZE); - - let (save_mic_raw_tx, save_mic_raw_rx) = if cfg!(debug_assertions) { - let (tx, rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - (Some(tx), Some(rx)) - } else { - (None, None) - }; - - let (save_speaker_raw_tx, save_speaker_raw_rx) = if cfg!(debug_assertions) { - let (tx, rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); - (Some(tx), Some(rx)) - } else { - (None, None) - }; - - Self { - mic_tx, - mic_rx, - speaker_tx, - speaker_rx, - save_mixed_tx, - save_mixed_rx, - save_mic_raw_tx, - save_mic_raw_rx, - save_speaker_raw_tx, - save_speaker_raw_rx, - process_mic_tx, - process_mic_rx, - process_speaker_tx, - process_speaker_rx, - } - } - - async fn process_mic_stream( - mut mic_stream: impl futures_util::Stream> + Unpin, - mic_muted_rx: tokio::sync::watch::Receiver, - mic_tx: flume::Sender>, - ) { - let mut is_muted = *mic_muted_rx.borrow(); - let watch_rx = mic_muted_rx.clone(); - - while let Some(actual) = mic_stream.next().await { - if watch_rx.has_changed().unwrap_or(false) { - is_muted = *watch_rx.borrow(); - } - - let maybe_muted = if is_muted { - vec![0.0; actual.len()] - } else { - actual - }; - - if let Err(e) = mic_tx.send_async(maybe_muted).await { - tracing::error!("mic_tx_send_error: {:?}", e); - break; - } - } - } - - async fn process_speaker_stream( - mut speaker_stream: impl futures_util::Stream> + Unpin, - speaker_muted_rx: tokio::sync::watch::Receiver, - speaker_tx: flume::Sender>, - ) { - let mut is_muted = *speaker_muted_rx.borrow(); - let watch_rx = speaker_muted_rx.clone(); - - while let Some(actual) = speaker_stream.next().await { - if watch_rx.has_changed().unwrap_or(false) { - is_muted = *watch_rx.borrow(); - } - - let maybe_muted = if is_muted { - vec![0.0; actual.len()] - } else { - actual - }; - - if let Err(e) = speaker_tx.send_async(maybe_muted).await { - tracing::error!("speaker_tx_send_error: {:?}", e); - break; - } - } - } -} - -pub struct Session { - app: tauri::AppHandle, - session_id: Option, - mic_device_name: Option, - mic_muted_tx: Option>, - mic_muted_rx: Option>, - speaker_muted_tx: Option>, - speaker_muted_rx: Option>, - silence_stream_tx: Option>, - session_state_tx: Option>, - tasks: Option>, - session_start_timestamp_ms: Option, -} - -impl Session { - pub fn new(app: tauri::AppHandle) -> Self { - let mic_device_name = hypr_audio::AudioInput::get_default_mic_device_name(); - - Self { - app, - session_id: None, - mic_device_name: Some(mic_device_name), - mic_muted_tx: None, - mic_muted_rx: None, - speaker_muted_tx: None, - speaker_muted_rx: None, - silence_stream_tx: None, - tasks: None, - session_state_tx: None, - session_start_timestamp_ms: None, - } - } - - #[tracing::instrument(skip_all)] - async fn setup_resources(&mut self, id: impl Into) -> Result<(), crate::Error> { - use tauri_plugin_db::{DatabasePluginExt, UserDatabase}; - - let session_id = id.into(); - let onboarding_session_id = UserDatabase::onboarding_session_id(); - - let user_id = self.app.db_user_id().await?.unwrap(); - self.session_id = Some(session_id.clone()); - - let (record, languages) = { - let config = self.app.db_get_config(&user_id).await?; - - let record = config - .as_ref() - .is_none_or(|c| c.general.save_recordings.unwrap_or(true)); - - let languages = config.as_ref().map_or_else( - || vec![hypr_language::ISO639::En.into()], - |c| c.general.spoken_languages.clone(), - ); - - (record, languages) - }; - - let session = self - .app - .db_get_session(&session_id) - .await? - .ok_or(crate::Error::NoneSession)?; - - self.session_start_timestamp_ms = Some( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as u64, - ); - - let (mic_muted_tx, mic_muted_rx_main) = tokio::sync::watch::channel(false); - let (speaker_muted_tx, speaker_muted_rx_main) = tokio::sync::watch::channel(false); - - let (stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1); - - self.mic_muted_tx = Some(mic_muted_tx); - self.mic_muted_rx = Some(mic_muted_rx_main.clone()); - self.speaker_muted_tx = Some(speaker_muted_tx); - self.speaker_muted_rx = Some(speaker_muted_rx_main.clone()); - - let listen_client = - setup_listen_client(&self.app, languages, session_id == onboarding_session_id).await?; - let mic_sample_stream = { - let mut input = hypr_audio::AudioInput::from_mic(self.mic_device_name.clone())?; - input.stream() - }; - let mic_stream = - ResampledAsyncSource::new(mic_sample_stream, SAMPLE_RATE).chunks(hypr_aec::BLOCK_SIZE); - - let speaker_sample_stream = hypr_audio::AudioInput::from_speaker().stream(); - let speaker_stream = ResampledAsyncSource::new(speaker_sample_stream, SAMPLE_RATE) - .chunks(hypr_aec::BLOCK_SIZE); - - let channels = AudioChannels::new(); - - { - let silence_stream_tx = hypr_audio::AudioOutput::silence(); - self.silence_stream_tx = Some(silence_stream_tx); - } - - let mut tasks = JoinSet::new(); - - tasks.spawn(AudioChannels::process_mic_stream( - mic_stream, - mic_muted_rx_main.clone(), - channels.mic_tx.clone(), - )); - - tasks.spawn(AudioChannels::process_speaker_stream( - speaker_stream, - speaker_muted_rx_main.clone(), - channels.speaker_tx.clone(), - )); - - let app_dir = self.app.path().app_data_dir().unwrap(); - - tasks.spawn({ - let app = self.app.clone(); - let mic_rx = channels.mic_rx.clone(); - let speaker_rx = channels.speaker_rx.clone(); - let save_mixed_tx = channels.save_mixed_tx.clone(); - let save_mic_raw_tx = channels.save_mic_raw_tx.clone(); - let save_speaker_raw_tx = channels.save_speaker_raw_tx.clone(); - let process_mic_tx = channels.process_mic_tx.clone(); - let process_speaker_tx = channels.process_speaker_tx.clone(); - - async move { - let mut aec = hypr_aec::AEC::new().unwrap(); - let mut mic_agc = hypr_agc::Agc::default(); - let mut speaker_agc = hypr_agc::Agc::default(); - let mut last_broadcast = Instant::now(); - - loop { - let (mut mic_chunk_raw, mut speaker_chunk): (Vec, Vec) = - match tokio::join!(mic_rx.recv_async(), speaker_rx.recv_async()) { - (Ok(mic), Ok(speaker)) => (mic, speaker), - _ => break, - }; - - mic_agc.process(&mut mic_chunk_raw); - speaker_agc.process(&mut speaker_chunk); - - let maybe_mic_chunk = aec.process_streaming(&mic_chunk_raw, &speaker_chunk); - - let mic_chunk = match maybe_mic_chunk { - Ok(mic_chunk) => mic_chunk, - Err(e) => { - tracing::error!("aec_error: {:?}", e); - mic_chunk_raw - } - }; - - let processed_mic = mic_chunk.clone(); - let processed_speaker = speaker_chunk.clone(); - - let now = Instant::now(); - if now.duration_since(last_broadcast) >= AUDIO_AMPLITUDE_THROTTLE { - if let Err(e) = SessionEvent::from((&mic_chunk, &speaker_chunk)).emit(&app) - { - tracing::error!("broadcast_error: {:?}", e); - } - last_broadcast = now; - } - - if let Some(ref tx) = save_mic_raw_tx { - let _ = tx.send_async(mic_chunk.clone()).await; - } - if let Some(ref tx) = save_speaker_raw_tx { - let _ = tx.send_async(speaker_chunk.clone()).await; - } - - if let Err(_) = process_mic_tx.send_async(processed_mic).await { - tracing::error!("process_mic_tx_send_error"); - return; - } - if let Err(_) = process_speaker_tx.send_async(processed_speaker).await { - tracing::error!("process_speaker_tx_send_error"); - return; - } - - if record { - let mixed: Vec = mic_chunk - .iter() - .zip(speaker_chunk.iter()) - .map(|(mic, speaker)| (mic + speaker).clamp(-1.0, 1.0)) - .collect(); - if save_mixed_tx.send_async(mixed).await.is_err() { - tracing::error!("save_mixed_tx_send_error"); - } - } - } - } - }); - - if record { - tasks.spawn({ - let app_dir = app_dir.clone(); - let session_id = session_id.clone(); - let save_mixed_rx = channels.save_mixed_rx.clone(); - - async move { - if let Err(e) = AudioSaver::save_to_wav( - save_mixed_rx, - &session_id, - &app_dir, - "audio.wav", - true, - ) - .await - { - tracing::error!("failed_to_save_mixed_audio: {:?}", e); - } - } - }); - } - - if let Some(save_mic_raw_rx) = channels.save_mic_raw_rx.clone() { - tasks.spawn({ - let session_id = session_id.clone(); - let app_dir = app_dir.clone(); - - async move { - if let Err(e) = AudioSaver::save_to_wav( - save_mic_raw_rx, - &session_id, - &app_dir, - "audio_mic.wav", - false, - ) - .await - { - tracing::error!("failed_to_save_raw_mic_audio: {:?}", e); - } - } - }); - } - - if let Some(save_speaker_raw_rx) = channels.save_speaker_raw_rx.clone() { - tasks.spawn({ - let session_id = session_id.clone(); - let app_dir = app_dir.clone(); - - async move { - if let Err(e) = AudioSaver::save_to_wav( - save_speaker_raw_rx, - &session_id, - &app_dir, - "audio_speaker.wav", - false, - ) - .await - { - tracing::error!("failed_to_save_raw_speaker_audio: {:?}", e); - } - } - }); - } - - let mic_audio_stream = channels - .process_mic_rx - .into_stream() - .map(|v| hypr_audio_utils::f32_to_i16_bytes(v.into_iter())); - - let speaker_audio_stream = channels - .process_speaker_rx - .into_stream() - .map(|v| hypr_audio_utils::f32_to_i16_bytes(v.into_iter())); - - let combined_audio_stream = - mic_audio_stream - .zip(speaker_audio_stream) - .map(|(mic, speaker)| { - owhisper_interface::MixedMessage::Audio((mic.into(), speaker.into())) - }); - - tasks.spawn({ - let app = self.app.clone(); - let stop_tx = stop_tx.clone(); - let session_start_timestamp_ms = self.session_start_timestamp_ms.unwrap_or(0); - - async move { - let (listen_stream, _listen_handle) = listen_client - .from_realtime_audio(combined_audio_stream) - .await - .unwrap(); - - futures_util::pin_mut!(listen_stream); - - let mut manager = - TranscriptManager::with_unix_timestamp(session_start_timestamp_ms); - - loop { - match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { - Ok(Some(response)) => { - let diff = manager.append(response.clone()); - - let partial_words_by_channel: HashMap< - usize, - Vec, - > = diff - .partial_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| owhisper_interface::Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - SessionEvent::PartialWords { - words: partial_words_by_channel, - } - .emit(&app) - .unwrap(); - - let final_words_by_channel: HashMap< - usize, - Vec, - > = diff - .final_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| owhisper_interface::Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - - update_session( - &app, - &session.id, - final_words_by_channel - .clone() - .values() - .flatten() - .cloned() - .collect(), - ) - .await - .unwrap(); - - SessionEvent::FinalWords { - words: final_words_by_channel, - } - .emit(&app) - .unwrap(); - } - Ok(None) => { - tracing::info!("listen_stream_ended"); - - // TODO: this not work - session still on ACTIVE - if stop_tx.send(()).await.is_err() { - tracing::warn!("failed_to_send_stop_signal"); - } - break; - } - Err(_) => { - tracing::info!("listen_stream_timeout"); - - if let Some(state) = app.try_state::() { - let mut guard = state.lock().await; - guard.fsm.handle(&crate::fsm::StateEvent::Stop).await; - } - } - } - } - } - }); - - let app_handle = self.app.clone(); - tasks.spawn(async move { - if stop_rx.recv().await.is_some() { - if let Some(state) = app_handle.try_state::() { - let mut guard = state.lock().await; - guard.fsm.handle(&crate::fsm::StateEvent::Stop).await; - } - } - }); - - self.tasks = Some(tasks); - - Ok(()) - } - - #[tracing::instrument(skip_all)] - async fn teardown_resources(&mut self) { - self.session_id = None; - - if let Some(tx) = self.silence_stream_tx.take() { - let _ = tx.send(()); - } - - if let Some(mut tasks) = self.tasks.take() { - tasks.abort_all(); - while let Some(res) = tasks.join_next().await { - let _ = res; - } - } - } - - pub fn is_mic_muted(&self) -> bool { - match &self.mic_muted_rx { - Some(rx) => *rx.borrow(), - None => false, - } - } - - pub fn is_speaker_muted(&self) -> bool { - match &self.speaker_muted_rx { - Some(rx) => *rx.borrow(), - None => false, - } - } - - pub fn get_available_mic_devices() -> Vec { - hypr_audio::AudioInput::list_mic_devices() - } - - pub fn get_current_mic_device(&self) -> Option { - self.mic_device_name.clone() - } -} - -async fn setup_listen_client( - app: &tauri::AppHandle, - languages: Vec, - is_onboarding: bool, -) -> Result { - let conn = { - use tauri_plugin_local_stt::LocalSttPluginExt; - app.get_connection().await? - }; - - Ok(owhisper_client::ListenClient::builder() - .api_base(conn.base_url) - .api_key(conn.api_key.unwrap_or_default()) - .params(owhisper_interface::ListenParams { - model: conn.model, - languages, - redemption_time_ms: Some(if is_onboarding { 60 } else { 400 }), - ..Default::default() - }) - .build_dual()) -} - -async fn update_session( - app: &tauri::AppHandle, - session_id: impl Into, - words: Vec, -) -> Result, crate::Error> { - use tauri_plugin_db::DatabasePluginExt; - - // TODO: not ideal. We might want to only do "update" everywhere instead of upserts. - // We do this because it is highly likely that the session fetched in the listener is stale (session can be updated on the React side). - let mut session = app - .db_get_session(session_id) - .await? - .ok_or(crate::Error::NoneSession)?; - - session.words.extend(words); - app.db_upsert_session(session.clone()).await.unwrap(); - - Ok(session.words) -} - -pub enum StateEvent { - Start(String), - Stop, - MicMuted(bool), - SpeakerMuted(bool), - MicChange(Option), -} - -#[state_machine( - initial = "State::inactive()", - on_transition = "Self::on_transition", - state(derive(Debug, Clone, PartialEq)) -)] -impl Session { - #[superstate] - async fn common(&mut self, event: &StateEvent) -> Response { - match event { - StateEvent::MicMuted(muted) => { - if let Some(tx) = &self.mic_muted_tx { - let _ = tx.send(*muted); - let _ = SessionEvent::MicMuted { value: *muted }.emit(&self.app); - } - Handled - } - StateEvent::SpeakerMuted(muted) => { - if let Some(tx) = &self.speaker_muted_tx { - let _ = tx.send(*muted); - let _ = SessionEvent::SpeakerMuted { value: *muted }.emit(&self.app); - } - Handled - } - StateEvent::MicChange(device_name) => { - self.mic_device_name = device_name.clone(); - - if self.session_id.is_some() && self.tasks.is_some() { - if let Some(session_id) = self.session_id.clone() { - self.teardown_resources().await; - self.setup_resources(&session_id).await.unwrap(); - } - } - - Handled - } - _ => Super, - } - } - - #[state(superstate = "common", entry_action = "enter_running_active")] - async fn running_active(&mut self, event: &StateEvent) -> Response { - match event { - StateEvent::Start(incoming_session_id) => match &self.session_id { - Some(current_id) if current_id != incoming_session_id => { - Transition(State::inactive()) - } - _ => Handled, - }, - StateEvent::Stop => Transition(State::inactive()), - _ => Super, - } - } - - #[state( - superstate = "common", - entry_action = "enter_inactive", - exit_action = "exit_inactive" - )] - async fn inactive(&mut self, event: &StateEvent) -> Response { - match event { - StateEvent::Start(id) => match self.setup_resources(id).await { - Ok(_) => Transition(State::running_active()), - Err(e) => { - // TODO: emit event - tracing::error!("error: {:?}", e); - Transition(State::inactive()) - } - }, - StateEvent::Stop => Handled, - _ => Super, - } - } - - #[action] - async fn enter_inactive(&mut self) { - { - use tauri_plugin_tray::TrayPluginExt; - let _ = self.app.set_start_disabled(false); - } - - { - use tauri_plugin_windows::{HyprWindow, WindowsPluginExt}; - let _ = self.app.window_hide(HyprWindow::Control); - } - - self.session_start_timestamp_ms = None; - - if let Some(session_id) = &self.session_id { - use tauri_plugin_db::DatabasePluginExt; - - if let Ok(Some(mut session)) = self.app.db_get_session(session_id).await { - session.record_end = Some(chrono::Utc::now()); - let _ = self.app.db_upsert_session(session).await; - } - } - - self.teardown_resources().await; - } - - #[action] - async fn exit_inactive(&mut self) { - use tauri_plugin_tray::TrayPluginExt; - let _ = self.app.set_start_disabled(true); - } - - #[action] - async fn enter_running_active(&mut self) { - // { - // use tauri_plugin_windows::{HyprWindow, WindowsPluginExt}; - // let _ = self.app.window_show(HyprWindow::Control); - // } - - if let Some(session_id) = &self.session_id { - use tauri_plugin_db::DatabasePluginExt; - - if let Ok(Some(mut session)) = self.app.db_get_session(session_id).await { - session.record_start = Some(chrono::Utc::now()); - let _ = self.app.db_upsert_session(session).await; - } - } - } - - fn on_transition(&mut self, source: &State, target: &State) { - #[cfg(debug_assertions)] - tracing::info!("transitioned from `{:?}` to `{:?}`", source, target); - - match target { - State::RunningActive {} => SessionEvent::RunningActive {}.emit(&self.app).unwrap(), - State::Inactive {} => SessionEvent::Inactive {}.emit(&self.app).unwrap(), - } - - if let Some(tx) = &self.session_state_tx { - let _ = tx.send(target.clone()); - } - } +#[derive(Debug, Clone)] +pub enum State { + RunningActive, + Inactive, } impl serde::Serialize for State { @@ -792,8 +10,8 @@ impl serde::Serialize for State { S: serde::Serializer, { match self { - State::Inactive {} => serializer.serialize_str("inactive"), - State::RunningActive {} => serializer.serialize_str("running_active"), + State::Inactive => serializer.serialize_str("inactive"), + State::RunningActive => serializer.serialize_str("running_active"), } } } diff --git a/plugins/listener/src/lib.rs b/plugins/listener/src/lib.rs index 7cf5ae4d20..1be2c6b96b 100644 --- a/plugins/listener/src/lib.rs +++ b/plugins/listener/src/lib.rs @@ -1,7 +1,8 @@ -use statig::awaitable::IntoStateMachineExt; +use ractor::{Actor, ActorRef}; use tauri::Manager; use tokio::sync::Mutex; +mod actors; mod commands; mod error; mod events; @@ -13,18 +14,26 @@ pub use error::*; pub use events::*; pub use ext::*; +use crate::actors::{SessionArgs, SessionMsg, SessionSupervisor}; + const PLUGIN_NAME: &str = "listener"; pub type SharedState = Mutex; pub struct State { - fsm: statig::awaitable::StateMachine, - _device_monitor_handle: Option, + supervisor: Option>, } impl State { - pub fn get_state(&self) -> fsm::State { - self.fsm.state().clone() + pub async fn get_state(&self) -> fsm::State { + if let Some(supervisor) = &self.supervisor { + match ractor::call_t!(supervisor, SessionMsg::GetState, 100) { + Ok(state) => state, + Err(_) => fsm::State::Inactive {}, + } + } else { + fsm::State::Inactive {} + } } } @@ -61,54 +70,42 @@ pub fn init() -> tauri::plugin::TauriPlugin { .setup(move |app, _api| { specta_builder.mount_events(app); - let handle = app.app_handle(); - let fsm = fsm::Session::new(handle.clone()).state_machine(); - - let device_monitor_handle = { - let (event_tx, event_rx) = std::sync::mpsc::channel(); - let device_monitor_handle = hypr_audio::DeviceMonitor::spawn(event_tx); - - let app_handle = handle.clone(); - std::thread::spawn(move || { - while let Ok(event) = event_rx.recv() { - if let hypr_audio::DeviceEvent::DefaultInputChanged { .. } = event { - let new_device = hypr_audio::AudioInput::get_default_mic_device_name(); - - let app_handle_clone = app_handle.clone(); - let device_name = new_device.clone(); - - app_handle_clone - .run_on_main_thread({ - let app_handle_inner = app_handle_clone.clone(); - let device_name_inner = device_name.clone(); - move || { - tauri::async_runtime::spawn(async move { - if let Some(state) = - app_handle_inner.try_state::() - { - let mut guard = state.lock().await; - let event = fsm::StateEvent::MicChange(Some( - device_name_inner, - )); - guard.fsm.handle(&event).await; - } - }); - } - }) - .ok(); - } - } - }); + let state: SharedState = Mutex::new(State { supervisor: None }); + app.manage(state); - device_monitor_handle - }; + let app_handle = app.app_handle().clone(); + + tokio::spawn(async move { + match Actor::spawn( + Some("session_supervisor".to_string()), + SessionSupervisor, + SessionArgs { + app: app_handle.clone(), + }, + ) + .await + { + Ok((supervisor_ref, join_handle)) => { + { + let state_ref = app_handle.state::(); + let mut state = state_ref.lock().await; + state.supervisor = Some(supervisor_ref); + } - let state: SharedState = Mutex::new(State { - fsm, - _device_monitor_handle: Some(device_monitor_handle), + tokio::spawn(async move { + if let Err(e) = join_handle.await { + tracing::error!("SessionSupervisor terminated with error: {:?}", e); + } else { + tracing::info!("SessionSupervisor terminated gracefully"); + } + }); + } + Err(e) => { + tracing::error!("Failed to spawn SessionSupervisor: {}", e); + } + } }); - app.manage(state); Ok(()) }) .build() diff --git a/plugins/notification/src/quit.rs b/plugins/notification/src/quit.rs index f8c211df7e..91ce388569 100644 --- a/plugins/notification/src/quit.rs +++ b/plugins/notification/src/quit.rs @@ -11,11 +11,10 @@ pub fn create_quit_handler( if let Some(shared_state) = app_handle.try_state::() { if let Ok(guard) = shared_state.try_lock() { - let state = guard.get_state(); - if !matches!( - state, - tauri_plugin_listener::fsm::State::RunningActive { .. } - ) { + let state = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(guard.get_state()) + }); + if !matches!(state, tauri_plugin_listener::fsm::State::RunningActive) { is_exit_intent = true; } else { is_exit_intent = app_handle