diff --git a/crates/audio/src/device_monitor.rs b/crates/audio/src/device_monitor.rs index 309c912242..4c9b8532b9 100644 --- a/crates/audio/src/device_monitor.rs +++ b/crates/audio/src/device_monitor.rs @@ -3,7 +3,7 @@ use std::thread::JoinHandle; #[derive(Debug, Clone)] pub enum DeviceEvent { - DefaultInputChanged { headphone: bool }, + DefaultInputChanged, DefaultOutputChanged { headphone: bool }, } @@ -63,7 +63,8 @@ impl DeviceMonitor { #[cfg(target_os = "macos")] mod macos { use super::*; - use cidre::{core_audio as ca, io, ns, os}; + use crate::utils::macos::is_headphone_from_default_output_device; + use cidre::{core_audio as ca, ns, os}; extern "C-unwind" fn listener( _obj_id: ca::Obj, @@ -77,11 +78,10 @@ mod macos { for addr in addresses { match addr.selector { ca::PropSelector::HW_DEFAULT_INPUT_DEVICE => { - let headphone = detect_headphones(ca::System::default_input_device().ok()); - let _ = event_tx.send(DeviceEvent::DefaultInputChanged { headphone }); + let _ = event_tx.send(DeviceEvent::DefaultInputChanged); } ca::PropSelector::HW_DEFAULT_OUTPUT_DEVICE => { - let headphone = detect_headphones(ca::System::default_output_device().ok()); + let headphone = is_headphone_from_default_output_device(); let _ = event_tx.send(DeviceEvent::DefaultOutputChanged { headphone }); } _ => {} @@ -90,23 +90,6 @@ mod macos { os::Status::NO_ERR } - fn detect_headphones(device: Option) -> bool { - match device { - Some(device) => match device.streams() { - Ok(streams) => streams.iter().any(|s| { - if let Ok(term_type) = s.terminal_type() { - term_type.0 == io::audio::output_term::HEADPHONES - || term_type == ca::StreamTerminalType::HEADPHONES - } else { - false - } - }), - Err(_) => false, - }, - None => false, - } - } - pub(super) fn monitor(event_tx: mpsc::Sender, stop_rx: mpsc::Receiver<()>) { let selectors = [ ca::PropSelector::HW_DEFAULT_INPUT_DEVICE, diff --git a/crates/audio/src/lib.rs b/crates/audio/src/lib.rs index 374c244a82..6a9a523c71 100644 --- a/crates/audio/src/lib.rs +++ b/crates/audio/src/lib.rs @@ -1,16 +1,20 @@ mod device_monitor; mod errors; mod mic; +mod mixed; mod norm; mod resampler; mod speaker; +mod utils; pub use device_monitor::*; pub use errors::*; pub use mic::*; +pub use mixed::*; pub use norm::*; pub use resampler::*; pub use speaker::*; +pub use utils::*; pub use cpal; use cpal::traits::{DeviceTrait, HostTrait}; @@ -62,6 +66,7 @@ impl AudioOutput { pub enum AudioSource { RealtimeMic, RealtimeSpeaker, + RealtimeMixed, Recorded, } @@ -69,14 +74,34 @@ pub struct AudioInput { source: AudioSource, mic: Option, speaker: Option, + mixed: Option, data: Option>, } impl AudioInput { - pub fn get_default_mic_device_name() -> String { - let host = cpal::default_host(); - let device = host.default_input_device().unwrap(); - device.name().unwrap_or("Unknown Microphone".to_string()) + pub fn get_default_mic_name() -> String { + let name = { + let host = cpal::default_host(); + let device = host.default_input_device().unwrap(); + device.name().unwrap_or("Unknown Microphone".to_string()) + }; + + name + } + + pub fn is_using_headphone() -> bool { + let headphone = { + #[cfg(target_os = "macos")] + { + utils::macos::is_headphone_from_default_output_device() + } + #[cfg(not(target_os = "macos"))] + { + false + } + }; + + headphone } pub fn list_mic_devices() -> Vec { @@ -101,6 +126,7 @@ impl AudioInput { source: AudioSource::RealtimeMic, mic: Some(mic), speaker: None, + mixed: None, data: None, }) } @@ -110,15 +136,30 @@ impl AudioInput { source: AudioSource::RealtimeSpeaker, mic: None, 
speaker: Some(SpeakerInput::new().unwrap()), + mixed: None, data: None, } } + #[cfg(target_os = "macos")] + pub fn from_mixed() -> Result { + let mixed = MixedInput::new().unwrap(); + + Ok(Self { + source: AudioSource::RealtimeMixed, + mic: None, + speaker: None, + mixed: Some(mixed), + data: None, + }) + } + pub fn from_recording(data: Vec) -> Self { Self { source: AudioSource::Recorded, mic: None, speaker: None, + mixed: None, data: Some(data), } } @@ -126,8 +167,9 @@ impl AudioInput { pub fn device_name(&self) -> String { match &self.source { AudioSource::RealtimeMic => self.mic.as_ref().unwrap().device_name(), - AudioSource::RealtimeSpeaker => "TODO".to_string(), - AudioSource::Recorded => "TODO".to_string(), + AudioSource::RealtimeSpeaker => "RealtimeSpeaker".to_string(), + AudioSource::RealtimeMixed => "Mixed Audio".to_string(), + AudioSource::Recorded => "Recorded".to_string(), } } @@ -139,6 +181,9 @@ impl AudioInput { AudioSource::RealtimeSpeaker => AudioStream::RealtimeSpeaker { speaker: self.speaker.take().unwrap().stream().unwrap(), }, + AudioSource::RealtimeMixed => AudioStream::RealtimeMixed { + mixed: self.mixed.take().unwrap().stream().unwrap(), + }, AudioSource::Recorded => AudioStream::Recorded { data: self.data.as_ref().unwrap().clone(), position: 0, @@ -150,6 +195,7 @@ impl AudioInput { pub enum AudioStream { RealtimeMic { mic: MicStream }, RealtimeSpeaker { speaker: SpeakerStream }, + RealtimeMixed { mixed: MixedStream }, Recorded { data: Vec, position: usize }, } @@ -166,7 +212,7 @@ impl Stream for AudioStream { match &mut *self { AudioStream::RealtimeMic { mic } => mic.poll_next_unpin(cx), AudioStream::RealtimeSpeaker { speaker } => speaker.poll_next_unpin(cx), - // assume pcm_s16le, without WAV header + AudioStream::RealtimeMixed { mixed } => mixed.poll_next_unpin(cx), AudioStream::Recorded { data, position } => { if *position + 2 <= data.len() { let bytes = [data[*position], data[*position + 1]]; @@ -192,7 +238,34 @@ impl kalosm_sound::AsyncSource for AudioStream { match self { AudioStream::RealtimeMic { mic } => mic.sample_rate(), AudioStream::RealtimeSpeaker { speaker } => speaker.sample_rate(), + AudioStream::RealtimeMixed { mixed } => mixed.sample_rate(), AudioStream::Recorded { .. 
} => 16000,
         }
     }
 }
+
+#[cfg(test)]
+pub(crate) fn play_sine_for_sec(seconds: u64) -> std::thread::JoinHandle<()> {
+    use rodio::{
+        cpal::SampleRate,
+        source::{Function::Sine, SignalGenerator, Source},
+        OutputStream,
+    };
+    use std::{
+        thread::{sleep, spawn},
+        time::Duration,
+    };
+
+    spawn(move || {
+        let (_stream, stream_handle) = OutputStream::try_default().unwrap();
+        let source = SignalGenerator::new(SampleRate(44100), 440.0, Sine);
+
+        let source = source
+            .convert_samples()
+            .take_duration(Duration::from_secs(seconds))
+            .amplify(0.01);
+
+        stream_handle.play_raw(source).unwrap();
+        sleep(Duration::from_secs(seconds));
+    })
+}
diff --git a/crates/audio/src/mixed/macos.rs b/crates/audio/src/mixed/macos.rs
new file mode 100644
index 0000000000..6f337bfbfe
--- /dev/null
+++ b/crates/audio/src/mixed/macos.rs
@@ -0,0 +1,320 @@
+use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
+use std::sync::{Arc, Mutex};
+use std::task::{Poll, Waker};
+
+use anyhow::Result;
+use futures_util::Stream;
+
+use ringbuf::{
+    traits::{Consumer, Producer, Split},
+    HeapCons, HeapProd, HeapRb,
+};
+
+use ca::aggregate_device_keys as agg_keys;
+use ca::sub_device_keys;
+use cidre::{arc, av, cat, cf, core_audio as ca, ns, os};
+
+pub struct MixedInput {
+    tap: ca::TapGuard,
+    agg_desc: arc::Retained<cf::DictionaryOf<cf::String, cf::Type>>,
+}
+
+pub struct MixedStream {
+    consumer: HeapCons<f32>,
+    _device: ca::hardware::StartedDevice<ca::AggregateDevice>,
+    _ctx: Box<MixedCtx>,
+    _tap: ca::TapGuard,
+    waker_state: Arc<Mutex<WakerState>>,
+    current_sample_rate: Arc<AtomicU32>,
+}
+
+impl MixedStream {
+    pub fn sample_rate(&self) -> u32 {
+        self.current_sample_rate.load(Ordering::Acquire)
+    }
+}
+
+struct WakerState {
+    waker: Option<Waker>,
+    has_data: bool,
+}
+
+struct MixedCtx {
+    format: arc::R<av::AudioFormat>,
+    producer: HeapProd<f32>,
+    waker_state: Arc<Mutex<WakerState>>,
+    current_sample_rate: Arc<AtomicU32>,
+    consecutive_drops: Arc<AtomicU32>,
+    should_terminate: Arc<AtomicBool>,
+}
+
+impl MixedInput {
+    pub fn new() -> Result<Self> {
+        let input_device = ca::System::default_input_device()?;
+        let input_uid = input_device.uid()?;
+
+        let output_device = ca::System::default_output_device()?;
+        let output_uid = output_device.uid()?;
+
+        let tap_desc = ca::TapDesc::with_mono_global_tap_excluding_processes(&ns::Array::new());
+        let tap = tap_desc.create_process_tap()?;
+
+        let agg_desc = Self::create_aggregate_description(&tap, &input_uid, &output_uid)?;
+
+        Ok(Self { tap, agg_desc })
+    }
+
+    fn create_aggregate_description(
+        tap: &ca::TapGuard,
+        input_uid: &cf::String,
+        output_uid: &cf::String,
+    ) -> Result<arc::Retained<cf::DictionaryOf<cf::String, cf::Type>>> {
+        let input_sub_device = cf::DictionaryOf::with_keys_values(
+            &[sub_device_keys::uid()],
+            &[input_uid.as_type_ref()],
+        );
+
+        let output_sub_device = cf::DictionaryOf::with_keys_values(
+            &[sub_device_keys::uid()],
+            &[output_uid.as_type_ref()],
+        );
+
+        let sub_tap = cf::DictionaryOf::with_keys_values(
+            &[ca::sub_device_keys::uid()],
+            &[tap.uid().unwrap().as_type_ref()],
+        );
+
+        let agg_desc = cf::DictionaryOf::with_keys_values(
+            &[
+                agg_keys::is_private(),
+                agg_keys::is_stacked(),
+                agg_keys::tap_auto_start(),
+                agg_keys::name(),
+                agg_keys::main_sub_device(),
+                agg_keys::uid(),
+                agg_keys::sub_device_list(),
+                agg_keys::tap_list(),
+                agg_keys::clock_device(),
+            ],
+            &[
+                cf::Boolean::value_true().as_type_ref(),
+                cf::Boolean::value_false(),
+                cf::Boolean::value_true(),
+                cf::str!(c"mixed-audio-tap"),
+                &output_uid,
+                &cf::Uuid::new().to_cf_string(),
+                &cf::ArrayOf::from_slice(&[input_sub_device.as_ref(), output_sub_device.as_ref()]),
+                &cf::ArrayOf::from_slice(&[sub_tap.as_ref()]),
+                &input_uid, // Use input device as clock source for consistency
+            ],
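+            // keys/values above describe one private aggregate device combining the default input, the default output, and the process tap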
+        );
+
+        Ok(agg_desc)
+    }
+
+    fn start_device(
+        &self,
+        ctx: &mut Box<MixedCtx>,
+    ) -> Result<ca::hardware::StartedDevice<ca::AggregateDevice>> {
+        extern "C" fn proc(
+            device: ca::Device,
+            _now: &cat::AudioTimeStamp,
+            input_data: &cat::AudioBufList<1>,
+            _input_time: &cat::AudioTimeStamp,
+            _output_data: &mut cat::AudioBufList<1>,
+            _output_time: &cat::AudioTimeStamp,
+            ctx: Option<&mut MixedCtx>,
+        ) -> os::Status {
+            let ctx = ctx.unwrap();
+
+            ctx.current_sample_rate.store(
+                device
+                    .nominal_sample_rate()
+                    .unwrap_or(ctx.format.absd().sample_rate) as u32,
+                Ordering::Release,
+            );
+
+            if let Some(view) =
+                av::AudioPcmBuf::with_buf_list_no_copy(&ctx.format, input_data, None)
+            {
+                let format = view.format();
+
+                if format.channel_count() > 1 {
+                    let frame_count = view.frame_len() as usize;
+                    let mut mixed_buffer = Vec::with_capacity(frame_count);
+
+                    for frame_idx in 0..frame_count {
+                        let mut mixed_sample = 0.0f32;
+                        let channel_count = format.channel_count() as usize;
+
+                        for channel in 0..channel_count {
+                            if let Some(channel_data) = view.data_f32_at(channel) {
+                                if frame_idx < channel_data.len() {
+                                    mixed_sample += channel_data[frame_idx];
+                                }
+                            }
+                        }
+
+                        // Average the mixed sample
+                        mixed_sample /= channel_count as f32;
+                        mixed_buffer.push(mixed_sample);
+                    }
+
+                    process_mixed_audio_data(ctx, &mixed_buffer);
+                } else if let Some(data) = view.data_f32_at(0) {
+                    process_mixed_audio_data(ctx, data);
+                }
+            } else if ctx.format.common_format() == av::audio::CommonFormat::PcmF32 {
+                let first_buffer = &input_data.buffers[0];
+                let byte_count = first_buffer.data_bytes_size as usize;
+                let float_count = byte_count / std::mem::size_of::<f32>();
+
+                if float_count > 0 && first_buffer.data != std::ptr::null_mut() {
+                    let data = unsafe {
+                        std::slice::from_raw_parts(first_buffer.data as *const f32, float_count)
+                    };
+                    process_mixed_audio_data(ctx, data);
+                }
+            }
+
+            os::Status::NO_ERR
+        }
+
+        let agg_device = ca::AggregateDevice::with_desc(&self.agg_desc)?;
+        let proc_id = agg_device.create_io_proc_id(proc, Some(ctx))?;
+        let started_device = ca::device_start(agg_device, Some(proc_id))?;
+
+        Ok(started_device)
+    }
+
+    pub fn stream(self) -> MixedStream {
+        let asbd = self.tap.asbd().unwrap();
+        let format = av::AudioFormat::with_asbd(&asbd).unwrap();
+
+        let buffer_size = 1024 * 128;
+        let rb = HeapRb::<f32>::new(buffer_size);
+        let (producer, consumer) = rb.split();
+
+        let waker_state = Arc::new(Mutex::new(WakerState {
+            waker: None,
+            has_data: false,
+        }));
+
+        let current_sample_rate = Arc::new(AtomicU32::new(asbd.sample_rate as u32));
+
+        let mut ctx = Box::new(MixedCtx {
+            format,
+            producer,
+            waker_state: waker_state.clone(),
+            current_sample_rate: current_sample_rate.clone(),
+            consecutive_drops: Arc::new(AtomicU32::new(0)),
+            should_terminate: Arc::new(AtomicBool::new(false)),
+        });
+
+        let device = self.start_device(&mut ctx).unwrap();
+
+        MixedStream {
+            consumer,
+            _device: device,
+            _ctx: ctx,
+            _tap: self.tap,
+            waker_state,
+            current_sample_rate,
+        }
+    }
+}
+
+fn process_mixed_audio_data(ctx: &mut MixedCtx, data: &[f32]) {
+    let buffer_size = data.len();
+    let pushed = ctx.producer.push_slice(data);
+
+    if pushed < buffer_size {
+        let consecutive = ctx.consecutive_drops.fetch_add(1, Ordering::AcqRel) + 1;
+
+        if consecutive > 10 {
+            ctx.should_terminate.store(true, Ordering::Release);
+            return;
+        }
+    } else {
+        ctx.consecutive_drops.store(0, Ordering::Release);
+    }
+
+    if pushed > 0 {
+        let should_wake = {
+            let mut waker_state = ctx.waker_state.lock().unwrap();
+            if !waker_state.has_data {
+                waker_state.has_data = true;
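+                // take the parked waker so it can be woken after the lock is released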
+                waker_state.waker.take()
+            } else {
+                None
+            }
+        };
+
+        if let Some(waker) = should_wake {
+            waker.wake();
+        }
+    }
+}
+
+impl Stream for MixedStream {
+    type Item = f32;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if let Some(sample) = self.consumer.try_pop() {
+            return Poll::Ready(Some(sample));
+        }
+
+        if self._ctx.should_terminate.load(Ordering::Acquire) {
+            return match self.consumer.try_pop() {
+                Some(sample) => Poll::Ready(Some(sample)),
+                None => Poll::Ready(None),
+            };
+        }
+
+        {
+            let mut state = self.waker_state.lock().unwrap();
+            state.has_data = false;
+            state.waker = Some(cx.waker().clone());
+        }
+
+        Poll::Pending
+    }
+}
+
+impl Drop for MixedStream {
+    fn drop(&mut self) {
+        self._ctx.should_terminate.store(true, Ordering::Release);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::play_sine_for_sec;
+
+    use futures_util::StreamExt;
+
+    #[tokio::test]
+    async fn test_macos() {
+        let input = MixedInput::new().unwrap();
+        let mut stream = input.stream();
+
+        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
+
+        let handle = play_sine_for_sec(2);
+
+        let mut buffer = Vec::new();
+        while let Some(sample) = stream.next().await {
+            buffer.push(sample);
+            if buffer.len() > 48000 {
+                break;
+            }
+        }
+
+        handle.join().unwrap();
+        assert!(buffer.iter().any(|x| *x != 0.0));
+    }
+}
diff --git a/crates/audio/src/mixed/mod.rs b/crates/audio/src/mixed/mod.rs
new file mode 100644
index 0000000000..bf492577a2
--- /dev/null
+++ b/crates/audio/src/mixed/mod.rs
@@ -0,0 +1,64 @@
+use anyhow::Result;
+use futures_util::{Stream, StreamExt};
+
+#[cfg(target_os = "macos")]
+mod macos;
+#[cfg(target_os = "macos")]
+type PlatformMixedInput = macos::MixedInput;
+#[cfg(target_os = "macos")]
+type PlatformMixedStream = macos::MixedStream;
+
+#[cfg(not(target_os = "macos"))]
+mod other;
+#[cfg(not(target_os = "macos"))]
+type PlatformMixedInput = other::MixedInput;
+#[cfg(not(target_os = "macos"))]
+type PlatformMixedStream = other::MixedStream;
+
+pub struct MixedInput {
+    inner: PlatformMixedInput,
+}
+
+impl MixedInput {
+    pub fn new() -> Result<Self> {
+        let inner = PlatformMixedInput::new()?;
+        Ok(Self { inner })
+    }
+
+    pub fn stream(self) -> Result<MixedStream> {
+        let inner = self.inner.stream();
+        Ok(MixedStream { inner })
+    }
+}
+
+pub struct MixedStream {
+    inner: PlatformMixedStream,
+}
+
+impl Stream for MixedStream {
+    type Item = f32;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.inner.poll_next_unpin(cx)
+    }
+}
+
+impl kalosm_sound::AsyncSource for MixedStream {
+    fn as_stream(&mut self) -> impl Stream<Item = f32> + '_ {
+        self
+    }
+
+    fn sample_rate(&self) -> u32 {
+        #[cfg(target_os = "macos")]
+        {
+            self.inner.sample_rate()
+        }
+        #[cfg(not(target_os = "macos"))]
+        {
+            0
+        }
+    }
+}
diff --git a/crates/audio/src/mixed/other.rs b/crates/audio/src/mixed/other.rs
new file mode 100644
index 0000000000..dd9819c9c4
--- /dev/null
+++ b/crates/audio/src/mixed/other.rs
@@ -0,0 +1,27 @@
+use std::task::Poll;
+
+use anyhow::Result;
+use futures_util::Stream;
+
+pub struct MixedInput {}
+
+impl MixedInput {
+    pub fn new() -> Result<Self> {
+        Ok(Self {})
+    }
+
+    pub fn stream(self) -> Result<MixedStream> {
+        Ok(MixedStream {})
+    }
+}
+
+pub struct MixedStream {}
+
+impl Stream for MixedStream {
+    type Item = f32;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        Poll::Pending
+    }
+}
diff --git a/crates/audio/src/speaker/macos.rs
b/crates/audio/src/speaker/macos.rs index 7c471c06f8..622e496770 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -52,13 +52,6 @@ impl SpeakerInput { let output_device = ca::System::default_output_device()?; let output_uid = output_device.uid()?; - tracing::info!( - name = ?output_device.name().unwrap_or("Unknown Speaker".into()), - nominal_sample_rate = ?output_device.nominal_sample_rate().unwrap(), - actual_sample_rate = ?output_device.actual_sample_rate().unwrap(), - "speaker_output_device" - ); - let sub_device = cf::DictionaryOf::with_keys_values( &[ca::sub_device_keys::uid()], &[output_uid.as_type_ref()], @@ -115,7 +108,7 @@ impl SpeakerInput { ctx.current_sample_rate.store( device - .actual_sample_rate() + .nominal_sample_rate() .unwrap_or(ctx.format.absd().sample_rate) as u32, Ordering::Release, ); diff --git a/crates/audio/src/speaker/mod.rs b/crates/audio/src/speaker/mod.rs index d29a10a60c..41905a87c4 100644 --- a/crates/audio/src/speaker/mod.rs +++ b/crates/audio/src/speaker/mod.rs @@ -99,33 +99,9 @@ impl kalosm_sound::AsyncSource for SpeakerStream { #[cfg(test)] mod tests { use super::*; - use serial_test::serial; - - fn play_sine_for_sec(seconds: u64) -> std::thread::JoinHandle<()> { - use rodio::{ - cpal::SampleRate, - source::{Function::Sine, SignalGenerator, Source}, - OutputStream, - }; - use std::{ - thread::{sleep, spawn}, - time::Duration, - }; + use crate::play_sine_for_sec; - spawn(move || { - let (_stream, stream_handle) = OutputStream::try_default().unwrap(); - let source = SignalGenerator::new(SampleRate(44100), 440.0, Sine); - - let source = source - .convert_samples() - .take_duration(Duration::from_secs(seconds)) - .amplify(0.01); - - println!("Playing sine for {} seconds", seconds); - stream_handle.play_raw(source).unwrap(); - sleep(Duration::from_secs(seconds)); - }) - } + use serial_test::serial; #[cfg(target_os = "macos")] #[tokio::test] diff --git a/crates/audio/src/utils.rs b/crates/audio/src/utils.rs new file mode 100644 index 0000000000..46a4f35773 --- /dev/null +++ b/crates/audio/src/utils.rs @@ -0,0 +1,40 @@ +#[cfg(target_os = "macos")] +pub mod macos { + use cidre::{core_audio as ca, io}; + + fn is_headphone_from_device(device: Option) -> bool { + match device { + Some(device) => match device.streams() { + Ok(streams) => streams.iter().any(|s| { + if let Ok(term_type) = s.terminal_type() { + term_type.0 == io::audio::output_term::HEADPHONES + || term_type == ca::StreamTerminalType::HEADPHONES + } else { + false + } + }), + Err(_) => false, + }, + None => false, + } + } + + pub fn is_headphone_from_default_output_device() -> bool { + let device = ca::System::default_output_device().ok(); + is_headphone_from_device(device) + } +} + +#[cfg(target_os = "macos")] +#[cfg(test)] +pub mod test { + use super::macos::*; + + #[test] + fn test_is_headphone_from_default_output_device() { + println!( + "is_headphone_from_default_output_device={}", + is_headphone_from_default_output_device() + ); + } +} diff --git a/plugins/listener/src/actors/listen.rs b/plugins/listener/src/actors/listen.rs index 6e101ab8ef..44b7563a52 100644 --- a/plugins/listener/src/actors/listen.rs +++ b/plugins/listener/src/actors/listen.rs @@ -47,123 +47,7 @@ impl Actor for ListenBridge { myself: ActorRef, args: Self::Arguments, ) -> Result { - let (tx, rx) = - tokio::sync::mpsc::channel::>(32); - - let conn = { - use tauri_plugin_local_stt::LocalSttPluginExt; - - match args.app.get_connection().await { - Ok(c) => c, - Err(e) => { - 
tracing::error!("failed_to_get_connection: {:?}", e); - return Err(ActorProcessingErr::from(e)); - } - } - }; - - let client = owhisper_client::ListenClient::builder() - .api_base(conn.base_url) - .api_key(conn.api_key.unwrap_or_default()) - .params(owhisper_interface::ListenParams { - model: conn.model, - languages: args.languages, - redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), - ..Default::default() - }) - .build_dual(); - - let rx_task = tokio::spawn({ - let app = args.app.clone(); - let session_id = args.session_id.clone(); - - async move { - let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); - let (listen_stream, _handle) = match client.from_realtime_audio(outbound).await { - Ok(res) => res, - Err(e) => { - tracing::error!("listen_ws_connect_failed: {:?}", e); - myself.stop(Some(format!("listen_ws_connect_failed: {:?}", e))); - return; - } - }; - futures_util::pin_mut!(listen_stream); - - let mut manager = TranscriptManager::with_unix_timestamp(args.session_start_ts_ms); - - loop { - match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { - Ok(Some(response)) => { - let diff = manager.append(response.clone()); - - let partial_words_by_channel: HashMap> = diff - .partial_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - - SessionEvent::PartialWords { - words: partial_words_by_channel, - } - .emit(&app) - .unwrap(); - - let final_words_by_channel: HashMap> = diff - .final_words - .iter() - .map(|(channel_idx, words)| { - ( - *channel_idx, - words - .iter() - .map(|w| Word2::from(w.clone())) - .collect::>(), - ) - }) - .collect(); - - update_session( - &app, - &session_id, - final_words_by_channel - .clone() - .values() - .flatten() - .cloned() - .collect(), - ) - .await - .unwrap(); - - SessionEvent::FinalWords { - words: final_words_by_channel, - } - .emit(&app) - .unwrap(); - } - Ok(None) => { - tracing::info!("listen_stream_ended"); - break; - } - Err(_) => { - tracing::info!("listen_stream_timeout"); - break; - } - } - } - - myself.stop(None); - } - }); - + let (tx, rx_task) = spawn_rx_task(args, myself).await.unwrap(); Ok(ListenState { tx, rx_task }) } @@ -191,6 +75,127 @@ impl Actor for ListenBridge { } } +async fn spawn_rx_task( + args: ListenArgs, + myself: ActorRef, +) -> Result< + ( + tokio::sync::mpsc::Sender>, + tokio::task::JoinHandle<()>, + ), + ActorProcessingErr, +> { + let (tx, rx) = tokio::sync::mpsc::channel::>(32); + + let app = args.app.clone(); + let session_id = args.session_id.clone(); + let session_start_ts_ms = args.session_start_ts_ms; + + let conn = { + use tauri_plugin_local_stt::LocalSttPluginExt; + app.get_connection().await? 
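+        // connection failures surface as an error from spawn_rx_task rather than being handled inline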
+ }; + + let client = owhisper_client::ListenClient::builder() + .api_base(conn.base_url) + .api_key(conn.api_key.unwrap_or_default()) + .params(owhisper_interface::ListenParams { + model: conn.model, + languages: args.languages, + redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), + ..Default::default() + }) + .build_dual(); + + let rx_task = tokio::spawn(async move { + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + let (listen_stream, _handle) = match client.from_realtime_audio(outbound).await { + Ok(res) => res, + Err(e) => { + tracing::error!("listen_ws_connect_failed: {:?}", e); + myself.stop(Some(format!("listen_ws_connect_failed: {:?}", e))); + return; + } + }; + futures_util::pin_mut!(listen_stream); + + let mut manager = TranscriptManager::with_unix_timestamp(session_start_ts_ms); + + loop { + match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { + Ok(Some(response)) => { + let diff = manager.append(response.clone()); + + let partial_words_by_channel: HashMap> = diff + .partial_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + SessionEvent::PartialWords { + words: partial_words_by_channel, + } + .emit(&app) + .unwrap(); + + let final_words_by_channel: HashMap> = diff + .final_words + .iter() + .map(|(channel_idx, words)| { + ( + *channel_idx, + words + .iter() + .map(|w| Word2::from(w.clone())) + .collect::>(), + ) + }) + .collect(); + + update_session( + &app, + &session_id, + final_words_by_channel + .clone() + .values() + .flatten() + .cloned() + .collect(), + ) + .await + .unwrap(); + + SessionEvent::FinalWords { + words: final_words_by_channel, + } + .emit(&app) + .unwrap(); + } + Ok(None) => { + tracing::info!("listen_stream_ended"); + break; + } + Err(_) => { + tracing::info!("listen_stream_timeout"); + break; + } + } + } + + myself.stop(None); + }); + + Ok((tx, rx_task)) +} + async fn update_session( app: &tauri::AppHandle, session_id: impl Into, diff --git a/plugins/listener/src/actors/processor.rs b/plugins/listener/src/actors/processor.rs index 50f930552a..96744168f1 100644 --- a/plugins/listener/src/actors/processor.rs +++ b/plugins/listener/src/actors/processor.rs @@ -16,8 +16,9 @@ const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); pub enum ProcMsg { Mic(AudioChunk), - Spk(AudioChunk), - AttachListen(ActorRef), + Speaker(AudioChunk), + Mixed(AudioChunk), + AttachListener(ActorRef), AttachRecorder(ActorRef), } @@ -81,7 +82,7 @@ impl Actor for AudioProcessor { st: &mut Self::State, ) -> Result<(), ActorProcessingErr> { match msg { - ProcMsg::AttachListen(actor) => st.listen = Some(actor), + ProcMsg::AttachListener(actor) => st.listen = Some(actor), ProcMsg::AttachRecorder(actor) => st.recorder = Some(actor), ProcMsg::Mic(mut c) => { st.agc_m.process(&mut c.data); @@ -90,13 +91,25 @@ impl Actor for AudioProcessor { st.joiner.push_mic(arc); process_ready(st).await; } - ProcMsg::Spk(mut c) => { + ProcMsg::Speaker(mut c) => { st.agc_s.process(&mut c.data); let arc = Arc::<[f32]>::from(c.data); st.last_spk = Some(arc.clone()); st.joiner.push_spk(arc); process_ready(st).await; } + ProcMsg::Mixed(mut c) => { + st.agc_m.process(&mut c.data); + + let empty_arc = Arc::<[f32]>::from(vec![0.0; c.data.len()]); + let arc = Arc::<[f32]>::from(c.data); + + st.last_mic = Some(empty_arc.clone()); + st.last_spk = Some(arc.clone()); + st.joiner.push_mic(empty_arc.clone()); + 
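+                // the mixed capture already carries both sides, so the mic lane gets silence and the speaker lane gets the tap data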
st.joiner.push_spk(arc); + process_ready(st).await; + } } Ok(()) } diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs index 948bbed5d8..829ef52234 100644 --- a/plugins/listener/src/actors/session.rs +++ b/plugins/listener/src/actors/session.rs @@ -283,7 +283,7 @@ impl SessionSupervisor { if state.record_enabled { let app_dir = state.app.path().app_data_dir().unwrap(); let (rec_ref, _) = Actor::spawn_linked( - Some("recorder".to_string()), + Some(Recorder::name()), Recorder, RecArgs { app_dir: app_dir.clone(), @@ -311,7 +311,7 @@ impl SessionSupervisor { ) .await?; state.listen = Some(listen_ref.clone()); - processor_ref.cast(ProcMsg::AttachListen(listen_ref))?; + processor_ref.cast(ProcMsg::AttachListener(listen_ref))?; { use tauri_plugin_tray::TrayPluginExt; diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 10d5d3dd56..e08c696ce3 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -28,7 +28,7 @@ pub struct SourceArgs { } pub struct SourceState { - device: Option, + mic_device: Option, proc: ActorRef, token: CancellationToken, mic_muted: Arc, @@ -63,18 +63,23 @@ impl Actor for SourceActor { let myself_clone = myself.clone(); std::thread::spawn(move || { while let Ok(event) = event_rx.recv() { - if let DeviceEvent::DefaultInputChanged { .. } = event { - let new_device = AudioInput::get_default_mic_device_name(); - let _ = myself_clone.cast(SourceCtrl::SetMicDevice(Some(new_device))); + match event { + DeviceEvent::DefaultInputChanged { .. } + | DeviceEvent::DefaultOutputChanged { .. } => { + let new_device = AudioInput::get_default_mic_name(); + let _ = myself_clone.cast(SourceCtrl::SetMicDevice(Some(new_device))); + } } } }); - let device = AudioInput::get_default_mic_device_name(); + let mic_device = args + .device + .or_else(|| Some(AudioInput::get_default_mic_name())); let silence_stream_tx = Some(hypr_audio::AudioOutput::silence()); let mut st = SourceState { - device: Some(device), + mic_device, proc: args.proc, token: args.token, mic_muted: Arc::new(AtomicBool::new(false)), @@ -114,11 +119,11 @@ impl Actor for SourceActor { } SourceCtrl::GetMicDevice(reply) => { if !reply.is_closed() { - let _ = reply.send(st.device.clone()); + let _ = reply.send(st.mic_device.clone()); } } SourceCtrl::SetMicDevice(dev) => { - st.device = dev; + st.mic_device = dev; if let Some(cancel_token) = st.stream_cancel_token.take() { cancel_token.cancel(); @@ -158,70 +163,123 @@ async fn start_source_loop( st: &mut SourceState, ) -> Result<(), ActorProcessingErr> { let myself2 = myself.clone(); - let proc = st.proc.clone(); let token = st.token.clone(); let mic_muted = st.mic_muted.clone(); let spk_muted = st.spk_muted.clone(); + let mic_device = st.mic_device.clone(); let stream_cancel_token = CancellationToken::new(); st.stream_cancel_token = Some(stream_cancel_token.clone()); - let mut mic_input = hypr_audio::AudioInput::from_mic(st.device.clone()).unwrap(); - let mic_stream = - ResampledAsyncSource::new(mic_input.stream(), SAMPLE_RATE).chunks(hypr_aec::BLOCK_SIZE); - - let spk_input = hypr_audio::AudioInput::from_speaker().stream(); - let spk_stream = ResampledAsyncSource::new(spk_input, SAMPLE_RATE).chunks(hypr_aec::BLOCK_SIZE); - - let handle = tokio::spawn(async move { - tokio::pin!(mic_stream); - tokio::pin!(spk_stream); - - loop { - tokio::select! 
{ - _ = token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - myself2.stop(None); - return (); - } - _ = stream_cancel_token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - return (); - } - mic_next = mic_stream.next() => { - if let Some(data) = mic_next { - let output_data = if mic_muted.load(Ordering::Relaxed) { - vec![0.0; data.len()] + #[cfg(target_os = "macos")] + let use_mixed = !AudioInput::is_using_headphone(); + + #[cfg(not(target_os = "macos"))] + let use_mixed = false; + + let handle = if use_mixed { + #[cfg(target_os = "macos")] + tokio::spawn(async move { + let mixed_stream = { + let mut mixed_input = AudioInput::from_mixed().unwrap(); + ResampledAsyncSource::new(mixed_input.stream(), SAMPLE_RATE) + .chunks(hypr_aec::BLOCK_SIZE) + }; + + tokio::pin!(mixed_stream); + + loop { + tokio::select! { + _ = token.cancelled() => { + drop(mixed_stream); + myself2.stop(None); + return; + } + _ = stream_cancel_token.cancelled() => { + drop(mixed_stream); + return; + } + mixed_next = mixed_stream.next() => { + if let Some(data) = mixed_next { + // TODO: should be able to mute each stream + let output_data = if mic_muted.load(Ordering::Relaxed) && spk_muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + data + }; + + let msg = ProcMsg::Mixed(AudioChunk{ data: output_data }); + let _ = proc.cast(msg); } else { - data - }; - - let msg = ProcMsg::Mic(AudioChunk{ data: output_data }); - let _ = proc.cast(msg); - } else { - break; + break; + } } } - spk_next = spk_stream.next() => { - if let Some(data) = spk_next { - let output_data = if spk_muted.load(Ordering::Relaxed) { - vec![0.0; data.len()] + } + }) + } else { + tokio::spawn(async move { + let mic_stream = { + let mut mic_input = hypr_audio::AudioInput::from_mic(mic_device).unwrap(); + ResampledAsyncSource::new(mic_input.stream(), SAMPLE_RATE) + .chunks(hypr_aec::BLOCK_SIZE) + }; + + let spk_stream = { + let mut spk_input = hypr_audio::AudioInput::from_speaker(); + ResampledAsyncSource::new(spk_input.stream(), SAMPLE_RATE) + .chunks(hypr_aec::BLOCK_SIZE) + }; + + tokio::pin!(mic_stream); + tokio::pin!(spk_stream); + + loop { + tokio::select! { + _ = token.cancelled() => { + drop(mic_stream); + drop(spk_stream); + myself2.stop(None); + return; + } + _ = stream_cancel_token.cancelled() => { + drop(mic_stream); + drop(spk_stream); + return; + } + mic_next = mic_stream.next() => { + if let Some(data) = mic_next { + let output_data = if mic_muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + data + }; + + let msg = ProcMsg::Mic(AudioChunk{ data: output_data }); + let _ = proc.cast(msg); } else { - data - }; - - let msg = ProcMsg::Spk(AudioChunk{ data: output_data }); - let _ = proc.cast(msg); - } else { - break; + break; + } + } + spk_next = spk_stream.next() => { + if let Some(data) = spk_next { + let output_data = if spk_muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + data + }; + + let msg = ProcMsg::Speaker(AudioChunk{ data: output_data }); + let _ = proc.cast(msg); + } else { + break; + } } } } - } - }); + }) + }; st.run_task = Some(handle); Ok(())
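
Usage sketch (illustrative, not part of the diff): one way the MixedInput/MixedStream wrappers added in crates/audio/src/mixed/mod.rs could be consumed on macOS. The function name and the 48_000-sample cutoff are assumptions for the example; error handling is reduced to anyhow::Result.

    use futures_util::StreamExt;
    use hypr_audio::MixedInput;

    // Collect roughly one second of the system mix (default input + default output)
    // as raw f32 samples at the device's native rate.
    async fn sample_mixed_capture() -> anyhow::Result<Vec<f32>> {
        let input = MixedInput::new()?;
        let mut stream = input.stream()?;

        let mut samples = Vec::new();
        while let Some(sample) = stream.next().await {
            samples.push(sample);
            if samples.len() >= 48_000 {
                break;
            }
        }
        Ok(samples)
    }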