Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 131 additions & 85 deletions plugins/listener/src/actors/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::SessionEvent;
const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(15 * 60);

pub enum ListenerMsg {
Audio(Bytes, Bytes),
AudioSingle(Bytes),
AudioDual(Bytes, Bytes),
StreamResponse(StreamResponse),
StreamError(String),
StreamEnded,
Expand All @@ -42,11 +43,16 @@ pub struct ListenerArgs {

pub struct ListenerState {
pub args: ListenerArgs,
tx: tokio::sync::mpsc::Sender<MixedMessage<(Bytes, Bytes), ControlMessage>>,
tx: ChannelSender,
rx_task: tokio::task::JoinHandle<()>,
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}

enum ChannelSender {
Single(tokio::sync::mpsc::Sender<MixedMessage<Bytes, ControlMessage>>),
Dual(tokio::sync::mpsc::Sender<MixedMessage<(Bytes, Bytes), ControlMessage>>),
}

pub struct ListenerActor;

impl ListenerActor {
Expand Down Expand Up @@ -98,8 +104,16 @@ impl Actor for ListenerActor {
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
ListenerMsg::Audio(mic, spk) => {
let _ = state.tx.try_send(MixedMessage::Audio((mic, spk)));
ListenerMsg::AudioSingle(audio) => {
if let ChannelSender::Single(tx) = &state.tx {
let _ = tx.try_send(MixedMessage::Audio(audio));
}
}

ListenerMsg::AudioDual(mic, spk) => {
if let ChannelSender::Dual(tx) = &state.tx {
let _ = tx.try_send(MixedMessage::Audio((mic, spk)));
}
}

ListenerMsg::StreamResponse(mut response) => {
Expand Down Expand Up @@ -178,15 +192,31 @@ async fn spawn_rx_task(
myself: ActorRef<ListenerMsg>,
) -> Result<
(
tokio::sync::mpsc::Sender<MixedMessage<(Bytes, Bytes), ControlMessage>>,
ChannelSender,
tokio::task::JoinHandle<()>,
tokio::sync::oneshot::Sender<()>,
),
ActorProcessingErr,
> {
let (tx, rx) = tokio::sync::mpsc::channel::<MixedMessage<(Bytes, Bytes), ControlMessage>>(32);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
if args.mode == crate::actors::ChannelMode::Single {
spawn_rx_task_single(args, myself).await
} else {
spawn_rx_task_dual(args, myself).await
}
}

fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams {
owhisper_interface::ListenParams {
model: Some(args.model.clone()),
languages: args.languages.clone(),
sample_rate: super::SAMPLE_RATE,
redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }),
keywords: args.keywords.clone(),
..Default::default()
}
}

fn build_extra(args: &ListenerArgs) -> (f64, Extra) {
let session_offset_secs = args.session_started_at.elapsed().as_secs_f64();
let started_unix_millis = args
.session_started_at_unix
Expand All @@ -199,87 +229,103 @@ async fn spawn_rx_task(
started_unix_millis,
};

(session_offset_secs, extra)
}

async fn spawn_rx_task_single(
args: ListenerArgs,
myself: ActorRef<ListenerMsg>,
) -> Result<
(
ChannelSender,
tokio::task::JoinHandle<()>,
tokio::sync::oneshot::Sender<()>,
),
ActorProcessingErr,
> {
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let (session_offset_secs, extra) = build_extra(&args);

let (tx, rx) = tokio::sync::mpsc::channel::<MixedMessage<Bytes, ControlMessage>>(32);

let rx_task = tokio::spawn(async move {
use crate::actors::ChannelMode;

if args.mode == ChannelMode::Single {
let client = owhisper_client::ListenClient::builder()
.api_base(args.base_url.clone())
.api_key(args.api_key.clone())
.params(owhisper_interface::ListenParams {
model: Some(args.model.clone()),
languages: args.languages.clone(),
sample_rate: super::SAMPLE_RATE,
redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }),
keywords: args.keywords.clone(),
..Default::default()
})
.build_single();

let outbound = tokio_stream::StreamExt::map(
tokio_stream::wrappers::ReceiverStream::new(rx),
|msg| match msg {
MixedMessage::Audio((_mic, spk)) => MixedMessage::Audio(spk),
MixedMessage::Control(c) => MixedMessage::Control(c),
},
);

let (listen_stream, handle) = match client.from_realtime_audio(outbound).await {
Ok(res) => res,
Err(e) => {
let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e)));
return;
}
};
futures_util::pin_mut!(listen_stream);

process_stream(
listen_stream,
handle,
myself,
shutdown_rx,
session_offset_secs,
extra.clone(),
)
.await;
} else {
let client = owhisper_client::ListenClient::builder()
.api_base(args.base_url)
.api_key(args.api_key)
.params(owhisper_interface::ListenParams {
model: Some(args.model),
languages: args.languages,
sample_rate: 16000,
redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }),
keywords: args.keywords,
..Default::default()
})
.build_dual();

let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);

let (listen_stream, handle) = match client.from_realtime_audio(outbound).await {
Ok(res) => res,
Err(e) => {
let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e)));
return;
}
};
futures_util::pin_mut!(listen_stream);

process_stream(
listen_stream,
handle,
myself,
shutdown_rx,
session_offset_secs,
extra.clone(),
)
.await;
}
let client = owhisper_client::ListenClient::builder()
.api_base(args.base_url.clone())
.api_key(args.api_key.clone())
.params(build_listen_params(&args))
.build_single();

let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);

let (listen_stream, handle) = match client.from_realtime_audio(outbound).await {
Ok(res) => res,
Err(e) => {
let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e)));
return;
}
};
futures_util::pin_mut!(listen_stream);

process_stream(
listen_stream,
handle,
myself,
shutdown_rx,
session_offset_secs,
extra,
)
.await;
});

Ok((ChannelSender::Single(tx), rx_task, shutdown_tx))
}

async fn spawn_rx_task_dual(
args: ListenerArgs,
myself: ActorRef<ListenerMsg>,
) -> Result<
(
ChannelSender,
tokio::task::JoinHandle<()>,
tokio::sync::oneshot::Sender<()>,
),
ActorProcessingErr,
> {
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let (session_offset_secs, extra) = build_extra(&args);

let (tx, rx) = tokio::sync::mpsc::channel::<MixedMessage<(Bytes, Bytes), ControlMessage>>(32);

let rx_task = tokio::spawn(async move {
let client = owhisper_client::ListenClient::builder()
.api_base(args.base_url.clone())
.api_key(args.api_key.clone())
.params(build_listen_params(&args))
.build_dual();

let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);

let (listen_stream, handle) = match client.from_realtime_audio(outbound).await {
Ok(res) => res,
Err(e) => {
let _ = myself.send_message(ListenerMsg::StreamStartFailed(format!("{:?}", e)));
return;
}
};
futures_util::pin_mut!(listen_stream);

process_stream(
listen_stream,
handle,
myself,
shutdown_rx,
session_offset_secs,
extra,
)
.await;
});

Ok((tx, rx_task, shutdown_tx))
Ok((ChannelSender::Dual(tx), rx_task, shutdown_tx))
}

async fn process_stream<S, E>(
Expand Down
2 changes: 1 addition & 1 deletion plugins/listener/src/actors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use recorder::*;
pub use source::*;

#[cfg(target_os = "macos")]
pub const SAMPLE_RATE: u32 = 24 * 1000;
pub const SAMPLE_RATE: u32 = 16 * 1000;
#[cfg(not(target_os = "macos"))]
pub const SAMPLE_RATE: u32 = 16 * 1000;

Expand Down
30 changes: 14 additions & 16 deletions plugins/listener/src/actors/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,12 @@ impl Pipeline {
fn dispatch(&mut self, mic: Arc<[f32]>, spk: Arc<[f32]>, mode: ChannelMode) {
if let Some(cell) = registry::where_is(RecorderActor::name()) {
let actor: ActorRef<RecMsg> = cell.into();
let mixed = Self::mix(mic.as_ref(), spk.as_ref());
if let Err(e) = actor.cast(RecMsg::Audio(mixed)) {
let audio_for_recording = if mode == ChannelMode::Single {
mic.to_vec()
} else {
Self::mix(mic.as_ref(), spk.as_ref())
};
if let Err(e) = actor.cast(RecMsg::Audio(audio_for_recording)) {
tracing::error!(error = ?e, "failed_to_send_audio_to_recorder");
}
}
Expand All @@ -387,23 +391,17 @@ impl Pipeline {
};

let actor: ActorRef<ListenerMsg> = cell.into();
let (mic_bytes, spk_bytes) = if mode == ChannelMode::Single {
let mixed = Self::mix(mic.as_ref(), spk.as_ref());
(
f32_to_i16_bytes(mic.iter().copied()),
f32_to_i16_bytes(mixed.iter().copied()),
)

let result = if mode == ChannelMode::Single {
let audio_bytes = f32_to_i16_bytes(mic.to_vec().iter().copied());
actor.cast(ListenerMsg::AudioSingle(audio_bytes))
} else {
(
f32_to_i16_bytes(mic.iter().copied()),
f32_to_i16_bytes(spk.iter().copied()),
)
let mic_bytes = f32_to_i16_bytes(mic.iter().copied());
let spk_bytes = f32_to_i16_bytes(spk.iter().copied());
actor.cast(ListenerMsg::AudioDual(mic_bytes, spk_bytes))
};

if actor
.cast(ListenerMsg::Audio(mic_bytes, spk_bytes))
.is_err()
{
if result.is_err() {
tracing::warn!(actor = ListenerActor::name(), "cast_failed");
return;
}
Expand Down
Loading