diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index cf241863b3..1ec882a0fa 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -16,7 +16,8 @@ use crate::SessionEvent; const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(15 * 60); pub enum ListenerMsg { - Audio(Bytes, Bytes), + AudioSingle(Bytes), + AudioDual(Bytes, Bytes), StreamResponse(StreamResponse), StreamError(String), StreamEnded, @@ -42,11 +43,16 @@ pub struct ListenerArgs { pub struct ListenerState { pub args: ListenerArgs, - tx: tokio::sync::mpsc::Sender>, + tx: ChannelSender, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, } +enum ChannelSender { + Single(tokio::sync::mpsc::Sender>), + Dual(tokio::sync::mpsc::Sender>), +} + pub struct ListenerActor; impl ListenerActor { @@ -98,8 +104,16 @@ impl Actor for ListenerActor { state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { match message { - ListenerMsg::Audio(mic, spk) => { - let _ = state.tx.try_send(MixedMessage::Audio((mic, spk))); + ListenerMsg::AudioSingle(audio) => { + if let ChannelSender::Single(tx) = &state.tx { + let _ = tx.try_send(MixedMessage::Audio(audio)); + } + } + + ListenerMsg::AudioDual(mic, spk) => { + if let ChannelSender::Dual(tx) = &state.tx { + let _ = tx.try_send(MixedMessage::Audio((mic, spk))); + } } ListenerMsg::StreamResponse(mut response) => { @@ -178,15 +192,31 @@ async fn spawn_rx_task( myself: ActorRef, ) -> Result< ( - tokio::sync::mpsc::Sender>, + ChannelSender, tokio::task::JoinHandle<()>, tokio::sync::oneshot::Sender<()>, ), ActorProcessingErr, > { - let (tx, rx) = tokio::sync::mpsc::channel::>(32); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + if args.mode == crate::actors::ChannelMode::Single { + spawn_rx_task_single(args, myself).await + } else { + spawn_rx_task_dual(args, myself).await + } +} + +fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams { + owhisper_interface::ListenParams { + model: Some(args.model.clone()), + languages: args.languages.clone(), + sample_rate: super::SAMPLE_RATE, + redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), + keywords: args.keywords.clone(), + ..Default::default() + } +} +fn build_extra(args: &ListenerArgs) -> (f64, Extra) { let session_offset_secs = args.session_started_at.elapsed().as_secs_f64(); let started_unix_millis = args .session_started_at_unix @@ -199,87 +229,103 @@ async fn spawn_rx_task( started_unix_millis, }; + (session_offset_secs, extra) +} + +async fn spawn_rx_task_single( + args: ListenerArgs, + myself: ActorRef, +) -> Result< + ( + ChannelSender, + tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, + ), + ActorProcessingErr, +> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let (session_offset_secs, extra) = build_extra(&args); + + let (tx, rx) = tokio::sync::mpsc::channel::>(32); + let rx_task = tokio::spawn(async move { - use crate::actors::ChannelMode; - - if args.mode == ChannelMode::Single { - let client = owhisper_client::ListenClient::builder() - .api_base(args.base_url.clone()) - .api_key(args.api_key.clone()) - .params(owhisper_interface::ListenParams { - model: Some(args.model.clone()), - languages: args.languages.clone(), - sample_rate: super::SAMPLE_RATE, - redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), - keywords: args.keywords.clone(), - ..Default::default() - }) - .build_single(); - - let outbound = tokio_stream::StreamExt::map( - tokio_stream::wrappers::ReceiverStream::new(rx), - |msg| match msg { - MixedMessage::Audio((_mic, spk)) => MixedMessage::Audio(spk), - MixedMessage::Control(c) => MixedMessage::Control(c), - }, - ); - - let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { - Ok(res) => res, - Err(e) => { - let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e))); - return; - } - }; - futures_util::pin_mut!(listen_stream); - - process_stream( - listen_stream, - handle, - myself, - shutdown_rx, - session_offset_secs, - extra.clone(), - ) - .await; - } else { - let client = owhisper_client::ListenClient::builder() - .api_base(args.base_url) - .api_key(args.api_key) - .params(owhisper_interface::ListenParams { - model: Some(args.model), - languages: args.languages, - sample_rate: 16000, - redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), - keywords: args.keywords, - ..Default::default() - }) - .build_dual(); - - let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); - - let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { - Ok(res) => res, - Err(e) => { - let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e))); - return; - } - }; - futures_util::pin_mut!(listen_stream); - - process_stream( - listen_stream, - handle, - myself, - shutdown_rx, - session_offset_secs, - extra.clone(), - ) - .await; - } + let client = owhisper_client::ListenClient::builder() + .api_base(args.base_url.clone()) + .api_key(args.api_key.clone()) + .params(build_listen_params(&args)) + .build_single(); + + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + + let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { + Ok(res) => res, + Err(e) => { + let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e))); + return; + } + }; + futures_util::pin_mut!(listen_stream); + + process_stream( + listen_stream, + handle, + myself, + shutdown_rx, + session_offset_secs, + extra, + ) + .await; + }); + + Ok((ChannelSender::Single(tx), rx_task, shutdown_tx)) +} + +async fn spawn_rx_task_dual( + args: ListenerArgs, + myself: ActorRef, +) -> Result< + ( + ChannelSender, + tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, + ), + ActorProcessingErr, +> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let (session_offset_secs, extra) = build_extra(&args); + + let (tx, rx) = tokio::sync::mpsc::channel::>(32); + + let rx_task = tokio::spawn(async move { + let client = owhisper_client::ListenClient::builder() + .api_base(args.base_url.clone()) + .api_key(args.api_key.clone()) + .params(build_listen_params(&args)) + .build_dual(); + + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + + let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { + Ok(res) => res, + Err(e) => { + let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e))); + return; + } + }; + futures_util::pin_mut!(listen_stream); + + process_stream( + listen_stream, + handle, + myself, + shutdown_rx, + session_offset_secs, + extra, + ) + .await; }); - Ok((tx, rx_task, shutdown_tx)) + Ok((ChannelSender::Dual(tx), rx_task, shutdown_tx)) } async fn process_stream( diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs index f114251fb4..77965fc4c9 100644 --- a/plugins/listener/src/actors/mod.rs +++ b/plugins/listener/src/actors/mod.rs @@ -11,7 +11,7 @@ pub use recorder::*; pub use source::*; #[cfg(target_os = "macos")] -pub const SAMPLE_RATE: u32 = 24 * 1000; +pub const SAMPLE_RATE: u32 = 16 * 1000; #[cfg(not(target_os = "macos"))] pub const SAMPLE_RATE: u32 = 16 * 1000; diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 3bb89dd5ed..e49edf5750 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -375,8 +375,12 @@ impl Pipeline { fn dispatch(&mut self, mic: Arc<[f32]>, spk: Arc<[f32]>, mode: ChannelMode) { if let Some(cell) = registry::where_is(RecorderActor::name()) { let actor: ActorRef = cell.into(); - let mixed = Self::mix(mic.as_ref(), spk.as_ref()); - if let Err(e) = actor.cast(RecMsg::Audio(mixed)) { + let audio_for_recording = if mode == ChannelMode::Single { + mic.to_vec() + } else { + Self::mix(mic.as_ref(), spk.as_ref()) + }; + if let Err(e) = actor.cast(RecMsg::Audio(audio_for_recording)) { tracing::error!(error = ?e, "failed_to_send_audio_to_recorder"); } } @@ -387,23 +391,17 @@ impl Pipeline { }; let actor: ActorRef = cell.into(); - let (mic_bytes, spk_bytes) = if mode == ChannelMode::Single { - let mixed = Self::mix(mic.as_ref(), spk.as_ref()); - ( - f32_to_i16_bytes(mic.iter().copied()), - f32_to_i16_bytes(mixed.iter().copied()), - ) + + let result = if mode == ChannelMode::Single { + let audio_bytes = f32_to_i16_bytes(mic.to_vec().iter().copied()); + actor.cast(ListenerMsg::AudioSingle(audio_bytes)) } else { - ( - f32_to_i16_bytes(mic.iter().copied()), - f32_to_i16_bytes(spk.iter().copied()), - ) + let mic_bytes = f32_to_i16_bytes(mic.iter().copied()); + let spk_bytes = f32_to_i16_bytes(spk.iter().copied()); + actor.cast(ListenerMsg::AudioDual(mic_bytes, spk_bytes)) }; - if actor - .cast(ListenerMsg::Audio(mic_bytes, spk_bytes)) - .is_err() - { + if result.is_err() { tracing::warn!(actor = ListenerActor::name(), "cast_failed"); return; }