diff --git a/Cargo.lock b/Cargo.lock index a70de31f39..ec1f3e3ffc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -919,12 +919,15 @@ name = "audio-utils" version = "0.1.0" dependencies = [ "bytes", + "dasp", + "data", "futures-util", "hound", "kalosm-sound", "rodio", "rubato", "thiserror 2.0.17", + "tokio", "vorbis_rs", ] @@ -2596,8 +2599,9 @@ dependencies = [ [[package]] name = "cidre" -version = "0.11.3" -source = "git+https://github.com/yury/cidre?rev=a9587fa#a9587fa1d4ff6f0d5d082849ae7c62880fd739f7" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "885c8d7613de18ff99dd2290ec7eb8885652c9d87bd45ad876f8eb0741384ef1" dependencies = [ "cidre-macros", "half", diff --git a/Cargo.toml b/Cargo.toml index 13533c6b83..efdd4df004 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -209,7 +209,7 @@ specta = "2.0.0-rc.22" specta-typescript = "0.0.9" tauri-specta = "2.0.0-rc.21" -cidre = { git = "https://github.com/yury/cidre", rev = "a9587fa" } +cidre = "0.11.4" cpal = "0.15.3" dasp = "0.11.0" flume = "0.11.1" diff --git a/crates/audio-utils/Cargo.toml b/crates/audio-utils/Cargo.toml index f062c9b527..68fa00a832 100644 --- a/crates/audio-utils/Cargo.toml +++ b/crates/audio-utils/Cargo.toml @@ -9,7 +9,12 @@ futures-util = { workspace = true } kalosm-sound = { workspace = true, default-features = false } thiserror = { workspace = true } +dasp = { workspace = true } hound = { workspace = true } rodio = { workspace = true } rubato = "0.16.2" vorbis_rs = { workspace = true } + +[dev-dependencies] +hypr-data = { workspace = true } +tokio = { workspace = true, features = ["full"] } diff --git a/crates/audio-utils/src/error.rs b/crates/audio-utils/src/error.rs index 44a6b26718..62a1ae219c 100644 --- a/crates/audio-utils/src/error.rs +++ b/crates/audio-utils/src/error.rs @@ -12,8 +12,6 @@ pub enum Error { Hound(#[from] hound::Error), #[error(transparent)] Vorbis(#[from] vorbis_rs::VorbisError), - #[error("vorbis channel count mismatch (expected {expected}, actual {actual})")] - ChannelCountMismatch { expected: u8, actual: u8 }, #[error("vorbis channel data length mismatch for channel {channel}")] ChannelDataLengthMismatch { channel: usize }, #[error("unsupported channel count {count}")] diff --git a/crates/audio-utils/src/lib.rs b/crates/audio-utils/src/lib.rs index 609deffc29..22111dce24 100644 --- a/crates/audio-utils/src/lib.rs +++ b/crates/audio-utils/src/lib.rs @@ -5,8 +5,11 @@ use futures_util::{Stream, StreamExt}; use kalosm_sound::AsyncSource; mod error; -pub use error::*; +mod resampler; mod vorbis; + +pub use error::*; +pub use resampler::*; pub use vorbis::*; pub use rodio::Source; diff --git a/crates/audio-utils/src/resampler/driver.rs b/crates/audio-utils/src/resampler/driver.rs new file mode 100644 index 0000000000..6421be1f52 --- /dev/null +++ b/crates/audio-utils/src/resampler/driver.rs @@ -0,0 +1,194 @@ +use std::collections::VecDeque; + +use rubato::Resampler; + +/// Wraps a rubato Resampler with queues to enable sample-by-sample input and fixed-size chunk output. +/// Manages buffering between the streaming input and the block-based resampler requirements. +pub(crate) struct RubatoChunkResampler, const CHANNELS: usize> { + resampler: R, + output_chunk_size: usize, + input_block_size: usize, + input_queue: VecDeque, + rubato_input_buffer: Vec>, + rubato_output_buffer: Vec>, + output_queue: VecDeque, +} + +impl, const CHANNELS: usize> RubatoChunkResampler { + /// Creates a new wrapper with pre-allocated buffers sized for the resampler's requirements. + /// Allocates capacity upfront to avoid reallocations during audio processing. + pub(crate) fn new(resampler: R, output_chunk_size: usize, input_block_size: usize) -> Self { + let rubato_input_buffer = resampler.input_buffer_allocate(false); + let rubato_output_buffer = resampler.output_buffer_allocate(true); + let output_queue_capacity = resampler.output_frames_max().max(output_chunk_size); + let input_queue_capacity = input_block_size.max(1) * CHANNELS; + + Self { + resampler, + output_chunk_size, + input_block_size, + input_queue: VecDeque::with_capacity(input_queue_capacity), + rubato_input_buffer, + rubato_output_buffer, + output_queue: VecDeque::with_capacity(output_queue_capacity), + } + } + + /// Checks whether any resampled output is available. + pub(crate) fn output_is_empty(&self) -> bool { + self.output_queue.is_empty() + } + + /// Checks whether at least one full output chunk is ready to be consumed. + pub(crate) fn has_full_chunk(&self) -> bool { + self.output_queue.len() >= self.output_chunk_size + } + + /// Extracts exactly one output chunk if available, leaving the rest in the queue. + /// Returns None if insufficient samples are available. + pub(crate) fn take_full_chunk(&mut self) -> Option> { + if self.output_queue.len() >= self.output_chunk_size { + Some(self.output_queue.drain(..self.output_chunk_size).collect()) + } else { + None + } + } + + /// Drains all available output samples regardless of chunk boundaries. + /// Used when flushing remaining samples at stream end. + pub(crate) fn take_all_output(&mut self) -> Option> { + if self.output_queue.is_empty() { + None + } else { + Some(self.output_queue.drain(..).collect()) + } + } + + /// Checks whether any input samples are waiting to be processed. + pub(crate) fn has_input(&self) -> bool { + !self.input_queue.is_empty() + } + + /// Queues a single input sample for resampling. + pub(crate) fn push_sample(&mut self, sample: f32) { + self.input_queue.push_back(sample); + } + + /// Processes all complete input blocks currently available in the queue. + /// Stops when insufficient input remains for another block. + /// Returns whether any output was produced. + pub(crate) fn process_all_ready_blocks(&mut self) -> Result { + let mut produced_output = false; + loop { + let frames_needed = self.resampler.input_frames_next(); + if self.input_queue.len() < frames_needed { + break; + } + if self.process_one_block()? { + produced_output = true; + } + } + Ok(produced_output) + } + + /// Processes exactly one input block if enough samples are available. + /// Returns whether output was produced. + pub(crate) fn process_one_block(&mut self) -> Result { + let frames_needed = self.resampler.input_frames_next(); + if self.input_queue.len() < frames_needed { + return Ok(false); + } + self.rubato_input_buffer[0].clear(); + self.rubato_input_buffer[0].extend(self.input_queue.drain(..frames_needed)); + let produced_output = self.process_staged_input()?; + self.rubato_input_buffer[0].clear(); + Ok(produced_output) + } + + /// Processes an incomplete input block, optionally padding with zeros to meet resampler requirements. + /// Used for handling the final partial block at stream end when zero_pad is true. + pub(crate) fn process_partial_block(&mut self, zero_pad: bool) -> Result { + if self.input_queue.is_empty() { + return Ok(false); + } + + let frames_needed = self.resampler.input_frames_next(); + let frames_available = self.input_queue.len(); + + if !zero_pad && frames_available < frames_needed { + return Ok(false); + } + + self.rubato_input_buffer[0].clear(); + self.rubato_input_buffer[0].extend(self.input_queue.drain(..frames_available)); + if frames_available < frames_needed { + if zero_pad { + self.rubato_input_buffer[0].resize(frames_needed, 0.0); + } else { + return Ok(false); + } + } + + let produced_output = self.process_staged_input()?; + self.rubato_input_buffer[0].clear(); + Ok(produced_output) + } + + /// Discards all pending input samples without processing them. + pub(crate) fn clear_input(&mut self) { + self.input_queue.clear(); + } + + /// Hot-swaps the underlying resampler instance while preserving queue state. + /// Reallocates buffers and adjusts capacities as needed. Clears input queue to prevent + /// mixing samples from different configurations. + pub(crate) fn rebind_resampler( + &mut self, + resampler: R, + output_chunk_size: usize, + input_block_size: usize, + ) { + self.resampler = resampler; + self.output_chunk_size = output_chunk_size; + self.input_block_size = input_block_size; + self.rubato_input_buffer = self.resampler.input_buffer_allocate(false); + self.rubato_output_buffer = self.resampler.output_buffer_allocate(true); + + let desired_output_capacity = self + .resampler + .output_frames_max() + .max(self.output_chunk_size); + if self.output_queue.capacity() < desired_output_capacity { + self.output_queue + .reserve(desired_output_capacity - self.output_queue.capacity()); + } + + let desired_input_capacity = self.input_block_size.max(1) * CHANNELS; + if self.input_queue.capacity() < desired_input_capacity { + self.input_queue + .reserve(desired_input_capacity - self.input_queue.capacity()); + } + self.clear_input(); + } + + /// Runs the resampler on the staged input buffer and queues the output. + /// Returns whether any output frames were produced. + fn process_staged_input(&mut self) -> Result { + let (_, frames_produced) = self.resampler.process_into_buffer( + &self.rubato_input_buffer[..], + &mut self.rubato_output_buffer[..], + None, + )?; + if frames_produced > 0 { + self.output_queue.extend( + self.rubato_output_buffer[0] + .iter() + .take(frames_produced) + .copied(), + ); + Ok(true) + } else { + Ok(false) + } + } +} diff --git a/crates/audio-utils/src/resampler/dynamic_new.rs b/crates/audio-utils/src/resampler/dynamic_new.rs new file mode 100644 index 0000000000..26a6124478 --- /dev/null +++ b/crates/audio-utils/src/resampler/dynamic_new.rs @@ -0,0 +1,179 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use super::driver::RubatoChunkResampler; +use futures_util::{pin_mut, Stream}; +use kalosm_sound::AsyncSource; +use rubato::{FastFixedIn, PolynomialDegree}; + +pub trait ResampleExtDynamicNew: AsyncSource + Sized + Unpin { + fn resampled_chunks( + self, + target_rate: u32, + output_chunk_size: usize, + ) -> Result, crate::Error> { + ResamplerDynamicNew::new(self, target_rate, output_chunk_size) + } +} + +impl ResampleExtDynamicNew for T where T: AsyncSource + Sized + Unpin {} + +pub struct ResamplerDynamicNew +where + S: AsyncSource + Unpin, +{ + source: S, + target_rate: u32, + output_chunk_size: usize, + input_block_size: usize, + driver: RubatoChunkResampler, 1>, + last_source_rate: u32, + draining: bool, +} + +impl ResamplerDynamicNew +where + S: AsyncSource + Unpin, +{ + pub fn new( + source: S, + target_rate: u32, + output_chunk_size: usize, + ) -> Result { + let source_rate = source.sample_rate(); + let input_block_size = output_chunk_size; + let ratio = target_rate as f64 / source_rate as f64; + let resampler = Self::create_resampler(ratio, input_block_size)?; + let driver = RubatoChunkResampler::new(resampler, output_chunk_size, input_block_size); + Ok(Self { + source, + target_rate, + output_chunk_size, + input_block_size, + driver, + last_source_rate: source_rate, + draining: false, + }) + } + + fn rebuild_resampler(&mut self, new_rate: u32) -> Result<(), crate::Error> { + let ratio = self.target_rate as f64 / new_rate as f64; + let resampler = Self::create_resampler(ratio, self.input_block_size)?; + self.driver + .rebind_resampler(resampler, self.output_chunk_size, self.input_block_size); + self.last_source_rate = new_rate; + Ok(()) + } + + fn try_yield_chunk(&mut self) -> Option> { + if self.driver.has_full_chunk() { + self.driver.take_full_chunk() + } else if self.draining && !self.driver.output_is_empty() { + self.driver.take_all_output() + } else { + None + } + } + + fn drain_for_rate_change(&mut self) -> Result { + self.driver.process_all_ready_blocks()?; + if self.driver.has_input() { + self.driver.process_partial_block(true)?; + } + Ok(self.driver.output_is_empty()) + } + + fn drain_at_eos(&mut self) -> Result<(), crate::Error> { + self.driver.process_all_ready_blocks()?; + if self.driver.has_input() { + self.driver.process_partial_block(true)?; + } + Ok(()) + } + + fn create_resampler( + ratio: f64, + input_block_size: usize, + ) -> Result, crate::Error> { + FastFixedIn::::new( + ratio, + 2.0, + PolynomialDegree::Quintic, + input_block_size.max(1), + 1, + ) + .map_err(Into::into) + } +} + +impl Stream for ResamplerDynamicNew +where + S: AsyncSource + Unpin, +{ + type Item = Result, crate::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = Pin::into_inner(self); + + loop { + if let Some(chunk) = me.try_yield_chunk() { + return Poll::Ready(Some(Ok(chunk))); + } + + if me.draining { + return Poll::Ready(None); + } + + let current_rate = me.source.sample_rate(); + if current_rate != me.last_source_rate { + match me.drain_for_rate_change() { + Ok(true) => { + if let Err(err) = me.rebuild_resampler(current_rate) { + return Poll::Ready(Some(Err(err))); + } + continue; + } + Ok(false) => { + if me.driver.has_full_chunk() { + if let Some(chunk) = me.driver.take_full_chunk() { + return Poll::Ready(Some(Ok(chunk))); + } + } + if !me.driver.output_is_empty() { + if let Some(chunk) = me.driver.take_all_output() { + return Poll::Ready(Some(Ok(chunk))); + } + } + continue; + } + Err(err) => return Poll::Ready(Some(Err(err))), + } + } + + match me.driver.process_all_ready_blocks() { + Ok(true) => continue, + Ok(false) => {} + Err(err) => return Poll::Ready(Some(Err(err))), + } + + let sample_poll = { + let inner = me.source.as_stream(); + pin_mut!(inner); + inner.poll_next(cx) + }; + + match sample_poll { + Poll::Ready(Some(sample)) => { + me.driver.push_sample(sample); + } + Poll::Ready(None) => { + if let Err(err) = me.drain_at_eos() { + return Poll::Ready(Some(Err(err))); + } + me.draining = true; + } + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/crates/audio-utils/src/resampler/dynamic_old.rs b/crates/audio-utils/src/resampler/dynamic_old.rs new file mode 100644 index 0000000000..b8fa3def70 --- /dev/null +++ b/crates/audio-utils/src/resampler/dynamic_old.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use dasp::interpolate::Interpolator; +use futures_util::{pin_mut, Stream}; +use kalosm_sound::AsyncSource; + +pub struct ResamplerDynamicOld { + source: S, + target_sample_rate: u32, + last_source_rate: u32, + ratio: f64, + + phase: f64, + + interp: dasp::interpolate::linear::Linear, + last_sample: f32, + seeded: bool, +} + +impl ResamplerDynamicOld { + pub fn new(source: S, target_sample_rate: u32) -> Self { + let initial_rate = source.sample_rate(); + Self { + source, + target_sample_rate, + last_source_rate: initial_rate, + ratio: initial_rate as f64 / target_sample_rate as f64, + phase: 0.0, + interp: dasp::interpolate::linear::Linear::new(0.0, 0.0), + last_sample: 0.0, + seeded: false, + } + } + + #[inline] + fn handle_rate_change(&mut self) { + let new_rate = self.source.sample_rate(); + if new_rate == self.last_source_rate { + return; + } + + self.last_source_rate = new_rate; + self.ratio = new_rate as f64 / self.target_sample_rate as f64; + self.phase = 0.0; + self.interp = dasp::interpolate::linear::Linear::new(self.last_sample, self.last_sample); + } +} + +impl Stream for ResamplerDynamicOld { + type Item = f32; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.get_mut(); + + me.handle_rate_change(); + + let inner = me.source.as_stream(); + pin_mut!(inner); + + if !me.seeded { + match inner.as_mut().poll_next(cx) { + Poll::Ready(Some(frame)) => { + me.last_sample = frame; + me.interp = dasp::interpolate::linear::Linear::new(frame, frame); + me.seeded = true; + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + + while me.phase >= 1.0 { + match inner.as_mut().poll_next(cx) { + Poll::Ready(Some(frame)) => { + me.phase -= 1.0; + me.last_sample = frame; + me.interp.next_source_frame(frame); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + + let out = me.interp.interpolate(me.phase); + me.phase += me.ratio; + Poll::Ready(Some(out)) + } +} + +impl AsyncSource for ResamplerDynamicOld { + fn as_stream(&mut self) -> impl Stream + '_ { + self + } + + fn sample_rate(&self) -> u32 { + self.target_sample_rate + } +} diff --git a/crates/audio-utils/src/resampler/mod.rs b/crates/audio-utils/src/resampler/mod.rs new file mode 100644 index 0000000000..17a597a9e4 --- /dev/null +++ b/crates/audio-utils/src/resampler/mod.rs @@ -0,0 +1,188 @@ +mod driver; +mod dynamic_new; +mod dynamic_old; +mod static_new; + +pub use dynamic_new::*; +pub use dynamic_old::*; +pub use static_new::*; + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::{Stream, StreamExt}; + use kalosm_sound::AsyncSource; + use rodio::Source; + use std::pin::Pin; + use std::task::{Context, Poll}; + + fn get_samples_with_rate(path: impl AsRef) -> (Vec, u32) { + let source = + rodio::Decoder::new(std::io::BufReader::new(std::fs::File::open(path).unwrap())) + .unwrap(); + + let sample_rate = rodio::Source::sample_rate(&source); + let samples = source.convert_samples::().collect(); + (samples, sample_rate) + } + + #[derive(Clone)] + struct DynamicRateSource { + segments: Vec<(Vec, u32)>, + current_segment: usize, + current_position: usize, + poll_count: usize, + pending_yield: bool, + } + + impl DynamicRateSource { + fn new(segments: Vec<(Vec, u32)>) -> Self { + Self { + segments, + current_segment: 0, + current_position: 0, + poll_count: 0, + pending_yield: false, + } + } + } + + impl AsyncSource for DynamicRateSource { + fn as_stream(&mut self) -> impl Stream + '_ { + DynamicRateStream { source: self } + } + + fn sample_rate(&self) -> u32 { + if self.current_segment < self.segments.len() { + self.segments[self.current_segment].1 + } else { + 16000 + } + } + } + + struct DynamicRateStream<'a> { + source: &'a mut DynamicRateSource, + } + + impl<'a> Stream for DynamicRateStream<'a> { + type Item = f32; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let source = &mut self.source; + + source.poll_count += 1; + if source.pending_yield { + source.pending_yield = false; + } else if source.poll_count % 1000 == 0 { + let waker = cx.waker().clone(); + source.pending_yield = true; + tokio::spawn(async move { + tokio::task::yield_now().await; + waker.wake(); + }); + return Poll::Pending; + } + + while source.current_segment < source.segments.len() { + let (samples, _rate) = &source.segments[source.current_segment]; + + if source.current_position < samples.len() { + let sample = samples[source.current_position]; + source.current_position += 1; + return Poll::Ready(Some(sample)); + } + + source.current_segment += 1; + source.current_position = 0; + } + + Poll::Ready(None) + } + } + + fn create_test_source() -> DynamicRateSource { + DynamicRateSource::new(vec![ + get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH), + ]) + } + + macro_rules! write_wav { + ($path:expr, $sample_rate:expr, $samples:expr $(,)?) => {{ + let spec = hound::WavSpec { + channels: 1, + sample_rate: $sample_rate, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }; + + let mut writer = hound::WavWriter::create($path, spec).unwrap(); + for sample in $samples { + writer.write_sample(sample).unwrap(); + } + writer.finalize().unwrap(); + }}; + } + + #[tokio::test] + async fn test_kalosm_builtin_resampler() { + let source = create_test_source(); + let resampled = source.resample(16000); + assert_eq!(resampled.collect::>().await.len(), 9896247); + } + + #[tokio::test] + async fn test_dynamic_old_resampler() { + let source = create_test_source(); + let samples = ResamplerDynamicOld::new(source, 16000) + .collect::>() + .await; + + assert_eq!(samples.len(), 2791777); + write_wav!("dynamic_old_resampler.wav", 16000, samples.iter().copied()); + } + + #[tokio::test] + async fn test_dynamic_new_resampler() { + let source = create_test_source(); + let chunk_size = 1920; + let resampler = ResamplerDynamicNew::new(source, 16000, chunk_size).unwrap(); + + let chunks: Vec<_> = resampler.collect().await; + let total_samples: usize = chunks.iter().map(|c| c.as_ref().unwrap().len()).sum(); + + assert!((total_samples as i64 - 2784000).abs() < 100000); + + write_wav!( + "dynamic_new_resampler.wav", + 16000, + chunks.iter().flatten().flatten().copied() + ); + } + + #[tokio::test] + async fn test_static_new_resampler() { + let static_source = DynamicRateSource::new(vec![get_samples_with_rate( + hypr_data::english_1::AUDIO_PART1_8000HZ_PATH, + )]); + + let chunk_size = 1920; + let resampler = ResamplerStaticNew::new(static_source, 16000, chunk_size).unwrap(); + + let chunks: Vec<_> = resampler.collect().await; + let total_samples: usize = chunks.iter().map(|c| c.as_ref().unwrap().len()).sum(); + + assert!(total_samples > 0); + + write_wav!( + "static_new_resampler.wav", + 16000, + chunks.iter().flatten().flatten().copied() + ); + } +} diff --git a/crates/audio-utils/src/resampler/static_new.rs b/crates/audio-utils/src/resampler/static_new.rs new file mode 100644 index 0000000000..e1128fe4b7 --- /dev/null +++ b/crates/audio-utils/src/resampler/static_new.rs @@ -0,0 +1,129 @@ +use futures_util::{pin_mut, Stream}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use kalosm_sound::AsyncSource; +use rubato::{FastFixedIn, PolynomialDegree}; + +use super::driver::RubatoChunkResampler; + +pub trait AsyncSourceChunkResampleExt: AsyncSource + Sized + Unpin { + fn resampled_chunks( + self, + target_rate: u32, + output_chunk_size: usize, + ) -> Result, crate::Error> { + ResamplerStaticNew::new(self, target_rate, output_chunk_size) + } +} + +impl AsyncSourceChunkResampleExt for T where T: AsyncSource + Sized + Unpin {} + +pub struct ResamplerStaticNew +where + S: AsyncSource + Unpin, +{ + source: S, + driver: RubatoChunkResampler, 1>, + finished: bool, +} + +impl ResamplerStaticNew +where + S: AsyncSource + Unpin, +{ + pub fn new( + source: S, + target_rate: u32, + output_chunk_size: usize, + ) -> Result { + let driver = Self::build_driver(&source, target_rate, output_chunk_size)?; + + Ok(Self { + source, + driver, + finished: false, + }) + } + + fn build_driver( + source: &S, + target_rate: u32, + output_chunk_size: usize, + ) -> Result, 1>, crate::Error> { + let source_rate = source.sample_rate(); + let input_block_size = output_chunk_size; + let ratio = target_rate as f64 / source_rate as f64; + + let resampler = FastFixedIn::::new( + ratio, + 2.0, + PolynomialDegree::Quintic, + input_block_size.max(1), + 1, + )?; + + let driver = RubatoChunkResampler::new(resampler, output_chunk_size, input_block_size); + Ok(driver) + } + + fn finalize(&mut self) -> Result<(), crate::Error> { + if self.finished { + return Ok(()); + } + + self.driver.process_all_ready_blocks()?; + + if self.driver.has_input() { + self.driver.process_partial_block(true)?; + } + + self.finished = true; + Ok(()) + } +} + +impl Stream for ResamplerStaticNew +where + S: AsyncSource + Unpin, +{ + type Item = Result, crate::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = Pin::into_inner(self); + + loop { + if let Some(chunk) = me.driver.take_full_chunk() { + return Poll::Ready(Some(Ok(chunk))); + } + + if me.finished { + return Poll::Ready(me.driver.take_all_output().map(Ok)); + } + + match me.driver.process_one_block() { + Ok(true) => continue, + Ok(false) => {} + Err(err) => return Poll::Ready(Some(Err(err))), + } + + let sample_poll = { + let inner = me.source.as_stream(); + pin_mut!(inner); + inner.poll_next(cx) + }; + + match sample_poll { + Poll::Ready(Some(sample)) => { + me.driver.push_sample(sample); + } + Poll::Ready(None) => { + if let Err(err) = me.finalize() { + return Poll::Ready(Some(Err(err))); + } + } + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/crates/audio/src/lib.rs b/crates/audio/src/lib.rs index 0f7479b962..b8dd7c77ef 100644 --- a/crates/audio/src/lib.rs +++ b/crates/audio/src/lib.rs @@ -2,7 +2,6 @@ mod device_monitor; mod errors; mod mic; mod norm; -mod resampler; mod speaker; mod utils; @@ -10,7 +9,6 @@ pub use device_monitor::*; pub use errors::*; pub use mic::*; pub use norm::*; -pub use resampler::*; pub use speaker::*; pub use utils::*; diff --git a/crates/audio/src/resampler.rs b/crates/audio/src/resampler.rs deleted file mode 100644 index 268cc1ee56..0000000000 --- a/crates/audio/src/resampler.rs +++ /dev/null @@ -1,248 +0,0 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; - -use dasp::interpolate::Interpolator; -use futures_util::{pin_mut, Stream}; -use kalosm_sound::AsyncSource; - -pub struct ResampledAsyncSource { - source: S, - target_sample_rate: u32, - last_source_rate: u32, - ratio: f64, - - phase: f64, - - interp: dasp::interpolate::linear::Linear, - last_sample: f32, - seeded: bool, -} - -impl ResampledAsyncSource { - pub fn new(source: S, target_sample_rate: u32) -> Self { - let initial_rate = source.sample_rate(); - Self { - source, - target_sample_rate, - last_source_rate: initial_rate, - ratio: initial_rate as f64 / target_sample_rate as f64, - phase: 0.0, - interp: dasp::interpolate::linear::Linear::new(0.0, 0.0), - last_sample: 0.0, - seeded: false, - } - } - - #[inline] - fn handle_rate_change(&mut self) { - let new_rate = self.source.sample_rate(); - if new_rate == self.last_source_rate { - return; - } - - self.last_source_rate = new_rate; - self.ratio = new_rate as f64 / self.target_sample_rate as f64; - self.phase = 0.0; - self.interp = dasp::interpolate::linear::Linear::new(self.last_sample, self.last_sample); - } -} - -impl Stream for ResampledAsyncSource { - type Item = f32; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let me = self.get_mut(); - - me.handle_rate_change(); - - let inner = me.source.as_stream(); - pin_mut!(inner); - - if !me.seeded { - match inner.as_mut().poll_next(cx) { - Poll::Ready(Some(frame)) => { - me.last_sample = frame; - me.interp = dasp::interpolate::linear::Linear::new(frame, frame); - me.seeded = true; - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - - while me.phase >= 1.0 { - match inner.as_mut().poll_next(cx) { - Poll::Ready(Some(frame)) => { - me.phase -= 1.0; - me.last_sample = frame; - me.interp.next_source_frame(frame); - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - - let out = me.interp.interpolate(me.phase); - me.phase += me.ratio; - Poll::Ready(Some(out)) - } -} - -impl AsyncSource for ResampledAsyncSource { - fn as_stream(&mut self) -> impl Stream + '_ { - self - } - - fn sample_rate(&self) -> u32 { - self.target_sample_rate - } -} - -#[cfg(test)] -mod tests { - use futures_util::{Stream, StreamExt}; - use kalosm_sound::AsyncSource; - use rodio::Source; - use std::pin::Pin; - use std::task::{Context, Poll}; - - use crate::ResampledAsyncSource; - - fn get_samples_with_rate(path: impl AsRef) -> (Vec, u32) { - let source = - rodio::Decoder::new(std::io::BufReader::new(std::fs::File::open(path).unwrap())) - .unwrap(); - - let sample_rate = AsyncSource::sample_rate(&source); - let samples = source.convert_samples::().collect(); - (samples, sample_rate) - } - - #[derive(Clone)] - struct DynamicRateSource { - segments: Vec<(Vec, u32)>, - current_segment: usize, - current_position: usize, - } - - impl DynamicRateSource { - fn new(segments: Vec<(Vec, u32)>) -> Self { - Self { - segments, - current_segment: 0, - current_position: 0, - } - } - } - - impl AsyncSource for DynamicRateSource { - fn as_stream(&mut self) -> impl Stream + '_ { - DynamicRateStream { source: self } - } - - fn sample_rate(&self) -> u32 { - if self.current_segment < self.segments.len() { - self.segments[self.current_segment].1 - } else { - unreachable!() - } - } - } - - struct DynamicRateStream<'a> { - source: &'a mut DynamicRateSource, - } - - impl<'a> Stream for DynamicRateStream<'a> { - type Item = f32; - - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - let source = &mut self.source; - - while source.current_segment < source.segments.len() { - let (samples, _rate) = &source.segments[source.current_segment]; - - if source.current_position < samples.len() { - let sample = samples[source.current_position]; - source.current_position += 1; - return Poll::Ready(Some(sample)); - } - - source.current_segment += 1; - source.current_position = 0; - } - - Poll::Ready(None) - } - } - - #[tokio::test] - async fn test_existing_resampler() { - let source = DynamicRateSource::new(vec![ - get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH), - ]); - - { - let resampled = source.clone().resample(16000); - assert!(resampled.collect::>().await.len() == 9896247); - } - - { - let mut resampled = source.clone().resample(16000); - - let mut out_wav = hound::WavWriter::create( - "./out_1.wav", - hound::WavSpec { - channels: 1, - sample_rate: 16000, - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, - }, - ) - .unwrap(); - while let Some(sample) = resampled.next().await { - out_wav.write_sample(sample).unwrap(); - } - } - } - - #[tokio::test] - async fn test_new_resampler() { - let source = DynamicRateSource::new(vec![ - get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH), - get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH), - ]); - - { - let resampled = ResampledAsyncSource::new(source.clone(), 16000); - assert!(resampled.collect::>().await.len() == 2791777); - } - - { - let mut resampled = ResampledAsyncSource::new(source.clone(), 16000); - - let mut out_wav = hound::WavWriter::create( - "./out_2.wav", - hound::WavSpec { - channels: 1, - sample_rate: 16000, - bits_per_sample: 32, - sample_format: hound::SampleFormat::Float, - }, - ) - .unwrap(); - while let Some(sample) = resampled.next().await { - out_wav.write_sample(sample).unwrap(); - } - } - } -} diff --git a/crates/audio/src/speaker/macos.rs b/crates/audio/src/speaker/macos.rs index b38cd9b92f..a62bacf776 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -1,4 +1,5 @@ -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::any::TypeId; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use std::task::{Poll, Waker}; @@ -30,6 +31,7 @@ pub struct SpeakerStream { _tap: ca::TapGuard, waker_state: Arc>, current_sample_rate: Arc, + read_buffer: Vec, } impl SpeakerStream { @@ -43,10 +45,10 @@ struct Ctx { producer: HeapProd, waker_state: Arc>, current_sample_rate: Arc, - consecutive_drops: Arc, - should_terminate: Arc, } +const CHUNK_SIZE: usize = 256; + impl SpeakerInput { pub fn new() -> Result { let output_device = ca::System::default_output_device()?; @@ -126,16 +128,41 @@ impl SpeakerInput { if let Some(data) = view.data_f32_at(0) { process_audio_data(ctx, data); } - } else if ctx.format.common_format() == av::audio::CommonFormat::PcmF32 { + } else { let first_buffer = &input_data.buffers[0]; - let byte_count = first_buffer.data_bytes_size as usize; - let float_count = byte_count / std::mem::size_of::(); - if float_count > 0 && first_buffer.data != std::ptr::null_mut() { - let data = unsafe { - std::slice::from_raw_parts(first_buffer.data as *const f32, float_count) - }; - process_audio_data(ctx, data); + if first_buffer.data_bytes_size == 0 || first_buffer.data.is_null() { + return os::Status::NO_ERR; + } + + match ctx.format.common_format() { + av::audio::CommonFormat::PcmF32 => { + process_samples(ctx, first_buffer, |sample: f32| sample); + } + av::audio::CommonFormat::PcmF64 => { + process_samples(ctx, first_buffer, |sample: f64| sample as f32); + } + av::audio::CommonFormat::PcmI32 => { + let scale = i32::MAX as f32; + process_samples(ctx, first_buffer, move |sample: i32| { + if sample == i32::MIN { + -1.0 + } else { + sample as f32 / scale + } + }); + } + av::audio::CommonFormat::PcmI16 => { + let scale = i16::MAX as f32; + process_samples(ctx, first_buffer, move |sample: i16| { + if sample == i16::MIN { + -1.0 + } else { + sample as f32 / scale + } + }); + } + _ => {} } } @@ -154,7 +181,7 @@ impl SpeakerInput { let format = av::AudioFormat::with_asbd(&asbd).unwrap(); - let buffer_size = 1024 * 128; + let buffer_size = CHUNK_SIZE * 4; let rb = HeapRb::::new(buffer_size); let (producer, consumer) = rb.split(); @@ -171,8 +198,6 @@ impl SpeakerInput { producer, waker_state: waker_state.clone(), current_sample_rate: current_sample_rate.clone(), - consecutive_drops: Arc::new(AtomicU32::new(0)), - should_terminate: Arc::new(AtomicBool::new(false)), }); let device = self.start_device(&mut ctx).unwrap(); @@ -184,24 +209,56 @@ impl SpeakerInput { _tap: self.tap, waker_state, current_sample_rate, + read_buffer: vec![0.0f32; CHUNK_SIZE], } } } -fn process_audio_data(ctx: &mut Ctx, data: &[f32]) { - let buffer_size = data.len(); - let pushed = ctx.producer.push_slice(data); +fn read_samples(buffer: &cat::AudioBuf) -> Option<&[T]> { + let byte_count = buffer.data_bytes_size as usize; + + if byte_count == 0 || buffer.data.is_null() { + return None; + } + + let sample_count = byte_count / std::mem::size_of::(); + if sample_count == 0 { + return None; + } + + Some(unsafe { std::slice::from_raw_parts(buffer.data as *const T, sample_count) }) +} - if pushed < buffer_size { - let consecutive = ctx.consecutive_drops.fetch_add(1, Ordering::AcqRel) + 1; +fn process_samples(ctx: &mut Ctx, buffer: &cat::AudioBuf, mut convert: F) +where + T: Copy + 'static, + F: FnMut(T) -> f32, +{ + if let Some(samples) = read_samples::(buffer) { + if samples.is_empty() { + return; + } - if consecutive > 10 { - ctx.should_terminate.store(true, Ordering::Release); + if TypeId::of::() == TypeId::of::() { + let data = unsafe { + std::slice::from_raw_parts(samples.as_ptr() as *const f32, samples.len()) + }; + process_audio_data(ctx, data); return; } - } else { - ctx.consecutive_drops.store(0, Ordering::Release); + + let mut converted = Vec::with_capacity(samples.len()); + for sample in samples { + converted.push(convert(*sample)); + } + if !converted.is_empty() { + process_audio_data(ctx, &converted); + } } +} + +fn process_audio_data(ctx: &mut Ctx, data: &[f32]) { + let pushed = ctx.producer.push_slice(data); if pushed > 0 { let should_wake = { @@ -221,26 +278,21 @@ fn process_audio_data(ctx: &mut Ctx, data: &[f32]) { } impl Stream for SpeakerStream { - type Item = f32; + type Item = Vec; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - if let Some(sample) = self.consumer.try_pop() { - return Poll::Ready(Some(sample)); - } + let this = self.as_mut().get_mut(); + let popped = this.consumer.pop_slice(&mut this.read_buffer); - if self._ctx.should_terminate.load(Ordering::Acquire) { - tracing::warn!("should_terminate"); - return match self.consumer.try_pop() { - Some(sample) => Poll::Ready(Some(sample)), - None => Poll::Ready(None), - }; + if popped > 0 { + return Poll::Ready(Some(this.read_buffer[..popped].to_vec())); } { - let mut state = self.waker_state.lock().unwrap(); + let mut state = this.waker_state.lock().unwrap(); state.has_data = false; state.waker = Some(cx.waker().clone()); } @@ -250,7 +302,5 @@ impl Stream for SpeakerStream { } impl Drop for SpeakerStream { - fn drop(&mut self) { - self._ctx.should_terminate.store(true, Ordering::Release); - } + fn drop(&mut self) {} } diff --git a/crates/audio/src/speaker/mod.rs b/crates/audio/src/speaker/mod.rs index cf9b1c50f9..de4513654b 100644 --- a/crates/audio/src/speaker/mod.rs +++ b/crates/audio/src/speaker/mod.rs @@ -55,7 +55,11 @@ impl SpeakerInput { #[cfg(any(target_os = "macos", target_os = "windows"))] pub fn stream(self) -> Result { let inner = self.inner.stream(); - Ok(SpeakerStream { inner }) + Ok(SpeakerStream { + inner, + buffer: Vec::new(), + buffer_idx: 0, + }) } #[cfg(not(any(target_os = "macos", target_os = "windows")))] @@ -69,6 +73,8 @@ impl SpeakerInput { // https://github.com/floneum/floneum/blob/50afe10/interfaces/kalosm-sound/src/source/mic.rs#L140 pub struct SpeakerStream { inner: PlatformSpeakerStream, + buffer: Vec, + buffer_idx: usize, } impl Stream for SpeakerStream { @@ -80,7 +86,27 @@ impl Stream for SpeakerStream { ) -> std::task::Poll> { #[cfg(any(target_os = "macos", target_os = "windows"))] { - self.inner.poll_next_unpin(cx) + if self.buffer_idx < self.buffer.len() { + let sample = self.buffer[self.buffer_idx]; + self.buffer_idx += 1; + return std::task::Poll::Ready(Some(sample)); + } + + match self.inner.poll_next_unpin(cx) { + std::task::Poll::Ready(Some(chunk)) => { + self.buffer = chunk; + self.buffer_idx = 0; + if !self.buffer.is_empty() { + let sample = self.buffer[0]; + self.buffer_idx = 1; + std::task::Poll::Ready(Some(sample)) + } else { + std::task::Poll::Pending + } + } + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + std::task::Poll::Pending => std::task::Poll::Pending, + } } #[cfg(not(any(target_os = "macos", target_os = "windows")))] diff --git a/crates/audio/src/speaker/windows.rs b/crates/audio/src/speaker/windows.rs index 60377c5887..b3690d9af3 100644 --- a/crates/audio/src/speaker/windows.rs +++ b/crates/audio/src/speaker/windows.rs @@ -41,10 +41,12 @@ impl SpeakerInput { error!("Audio initialization failed: {}", e); } + const CHUNK_SIZE: usize = 256; SpeakerStream { sample_queue, waker_state, capture_thread: Some(capture_thread), + read_buffer: vec![0.0f32; CHUNK_SIZE], } } } @@ -59,6 +61,7 @@ pub struct SpeakerStream { sample_queue: Arc>>, waker_state: Arc>, capture_thread: Option>, + read_buffer: Vec, } impl SpeakerStream { @@ -183,9 +186,9 @@ impl Drop for SpeakerStream { } impl Stream for SpeakerStream { - type Item = f32; + type Item = Vec; fn poll_next( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { { @@ -197,8 +200,12 @@ impl Stream for SpeakerStream { { let mut queue = self.sample_queue.lock().unwrap(); - if let Some(sample) = queue.pop_front() { - return Poll::Ready(Some(sample)); + if !queue.is_empty() { + let chunk_len = queue.len().min(self.read_buffer.len()); + for i in 0..chunk_len { + self.read_buffer[i] = queue.pop_front().unwrap(); + } + return Poll::Ready(Some(self.read_buffer[..chunk_len].to_vec())); } } @@ -214,9 +221,14 @@ impl Stream for SpeakerStream { { let mut queue = self.sample_queue.lock().unwrap(); - match queue.pop_front() { - Some(sample) => Poll::Ready(Some(sample)), - None => Poll::Pending, + if !queue.is_empty() { + let chunk_len = queue.len().min(self.read_buffer.len()); + for i in 0..chunk_len { + self.read_buffer[i] = queue.pop_front().unwrap(); + } + Poll::Ready(Some(self.read_buffer[..chunk_len].to_vec())) + } else { + Poll::Pending } } } diff --git a/crates/ws-utils/src/lib.rs b/crates/ws-utils/src/lib.rs index 4033f1c65e..781a4c1281 100644 --- a/crates/ws-utils/src/lib.rs +++ b/crates/ws-utils/src/lib.rs @@ -1,6 +1,9 @@ mod manager; pub use manager::*; +use std::pin::Pin; +use std::task::{Context, Poll}; + use axum::extract::ws::{Message, WebSocket}; use futures_util::{stream::SplitStream, Stream, StreamExt}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; @@ -81,6 +84,8 @@ fn mix_audio_channels(mic: &[f32], speaker: &[f32]) -> Vec { pub struct WebSocketAudioSource { receiver: Option>, sample_rate: u32, + buffer: Vec, + buffer_idx: usize, } impl WebSocketAudioSource { @@ -88,30 +93,61 @@ impl WebSocketAudioSource { Self { receiver: Some(receiver), sample_rate, + buffer: Vec::new(), + buffer_idx: 0, } } } -impl kalosm_sound::AsyncSource for WebSocketAudioSource { - fn as_stream(&mut self) -> impl Stream + '_ { - let receiver = self.receiver.as_mut().unwrap(); +impl Stream for WebSocketAudioSource { + type Item = f32; - futures_util::stream::unfold(receiver, |receiver| async move { - match receiver.next().await { - Some(Ok(message)) => match process_ws_message(message, None) { - AudioProcessResult::Samples(samples) => Some((samples, receiver)), + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.buffer_idx < self.buffer.len() { + let sample = self.buffer[self.buffer_idx]; + self.buffer_idx += 1; + return Poll::Ready(Some(sample)); + } + + self.buffer.clear(); + self.buffer_idx = 0; + + let Some(receiver) = self.receiver.as_mut() else { + return Poll::Ready(None); + }; + + match Pin::new(receiver).poll_next(cx) { + Poll::Ready(Some(Ok(message))) => match process_ws_message(message, None) { + AudioProcessResult::Samples(mut samples) => { + if samples.is_empty() { + continue; + } + self.buffer.append(&mut samples); + self.buffer_idx = 0; + } AudioProcessResult::DualSamples { mic, speaker } => { - let mixed = mix_audio_channels(&mic, &speaker); - Some((mixed, receiver)) + let mut mixed = mix_audio_channels(&mic, &speaker); + if mixed.is_empty() { + continue; + } + self.buffer.append(&mut mixed); + self.buffer_idx = 0; } - AudioProcessResult::Empty => Some((Vec::new(), receiver)), - AudioProcessResult::End => None, + AudioProcessResult::Empty => continue, + AudioProcessResult::End => return Poll::Ready(None), }, - Some(Err(_)) => None, - None => None, + Poll::Ready(Some(Err(_))) => return Poll::Ready(None), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, } - }) - .flat_map(futures_util::stream::iter) + } + } +} + +impl kalosm_sound::AsyncSource for WebSocketAudioSource { + fn as_stream(&mut self) -> impl Stream + '_ { + self } fn sample_rate(&self) -> u32 { @@ -122,6 +158,8 @@ impl kalosm_sound::AsyncSource for WebSocketAudioSource { pub struct ChannelAudioSource { receiver: Option>>, sample_rate: u32, + buffer: Vec, + buffer_idx: usize, } impl ChannelAudioSource { @@ -129,17 +167,48 @@ impl ChannelAudioSource { Self { receiver: Some(receiver), sample_rate, + buffer: Vec::new(), + buffer_idx: 0, + } + } +} + +impl Stream for ChannelAudioSource { + type Item = f32; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.buffer_idx < self.buffer.len() { + let sample = self.buffer[self.buffer_idx]; + self.buffer_idx += 1; + return Poll::Ready(Some(sample)); + } + + self.buffer.clear(); + self.buffer_idx = 0; + + let Some(receiver) = self.receiver.as_mut() else { + return Poll::Ready(None); + }; + + match receiver.poll_recv(cx) { + Poll::Ready(Some(mut samples)) => { + if samples.is_empty() { + continue; + } + self.buffer.append(&mut samples); + self.buffer_idx = 0; + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } } } } impl kalosm_sound::AsyncSource for ChannelAudioSource { fn as_stream(&mut self) -> impl Stream + '_ { - let receiver = self.receiver.as_mut().unwrap(); - futures_util::stream::unfold(receiver, |receiver| async move { - receiver.recv().await.map(|samples| (samples, receiver)) - }) - .flat_map(futures_util::stream::iter) + self } fn sample_rate(&self) -> u32 { diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 4b929d9c59..ca9e5e05a8 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -6,10 +6,8 @@ use ractor::{registry, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyP use tokio_util::sync::CancellationToken; use crate::actors::{AudioChunk, ChannelMode, ListenerActor, ListenerMsg, ProcMsg, ProcessorActor}; -use hypr_audio::{ - is_using_headphone, AudioInput, DeviceEvent, DeviceMonitor, DeviceMonitorHandle, - ResampledAsyncSource, -}; +use hypr_audio::{is_using_headphone, AudioInput, DeviceEvent, DeviceMonitor, DeviceMonitorHandle}; +use hypr_audio_utils::ResampleExtDynamicNew; const SAMPLE_RATE: u32 = 16000; @@ -236,13 +234,19 @@ async fn start_source_loop( let mic_stream = { let mut mic_input = AudioInput::from_mic(mic_device).unwrap(); let chunk_size = chunk_size_from_sample_rate(SAMPLE_RATE); - ResampledAsyncSource::new(mic_input.stream(), SAMPLE_RATE).chunks(chunk_size) + mic_input + .stream() + .resampled_chunks(SAMPLE_RATE, chunk_size) + .unwrap() }; tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; let spk_stream = { let mut spk_input = hypr_audio::AudioInput::from_speaker(); let chunk_size = chunk_size_from_sample_rate(SAMPLE_RATE); - ResampledAsyncSource::new(spk_input.stream(), SAMPLE_RATE).chunks(chunk_size) + spk_input + .stream() + .resampled_chunks(SAMPLE_RATE, chunk_size) + .unwrap() }; tokio::pin!(mic_stream); @@ -268,24 +272,32 @@ async fn start_source_loop( return; } mic_next = mic_stream.next() => { - if let Some(data) = mic_next { - let output_data = if mic_muted.load(Ordering::Relaxed) { - vec![0.0; data.len()] - } else { - data - }; - let msg = ProcMsg::Mic(AudioChunk { data: output_data }); - let _ = proc.cast(msg); - } else { - break; + match mic_next { + Some(Ok(data)) => { + let output_data = if mic_muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + data + }; + let msg = ProcMsg::Mic(AudioChunk { data: output_data }); + let _ = proc.cast(msg); + } + Some(Err(err)) => { + tracing::warn!(error = ?err, "mic_resample_failed"); + } + None => break, } } spk_next = spk_stream.next() => { - if let Some(data) = spk_next { - let msg = ProcMsg::Speaker(AudioChunk{ data }); - let _ = proc.cast(msg); - } else { - break; + match spk_next { + Some(Ok(data)) => { + let msg = ProcMsg::Speaker(AudioChunk{ data }); + let _ = proc.cast(msg); + } + Some(Err(err)) => { + tracing::warn!(error = ?err, "speaker_resample_failed"); + } + None => break, } } } diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 70b2d1c971..e008c2218b 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -139,6 +139,7 @@ impl Actor for ExternalSTTActor { && !text.contains("Sent interim text:") && !text.contains("[TranscriptionHandler]") && !text.contains("/v1/status") + && !text.contains("text:") { tracing::info!("{}", text); } diff --git a/plugins/local-stt/src/server/supervisor.rs b/plugins/local-stt/src/server/supervisor.rs index cb2bec7f80..2a760f02fb 100644 --- a/plugins/local-stt/src/server/supervisor.rs +++ b/plugins/local-stt/src/server/supervisor.rs @@ -19,9 +19,9 @@ pub const SUPERVISOR_NAME: &str = "stt_supervisor"; pub async fn spawn_stt_supervisor() -> Result, ActorProcessingErr> { let options = DynamicSupervisorOptions { max_children: Some(1), - max_restarts: 5, - max_window: Duration::from_secs(5), - reset_after: None, + max_restarts: 15, + max_window: Duration::from_secs(30), + reset_after: Some(Duration::from_secs(60)), }; let (supervisor_ref, _handle) = @@ -123,7 +123,7 @@ pub async fn stop_all_stt_servers( } async fn wait_for_actor_shutdown(actor_name: ractor::ActorName) { - for _ in 0..20 { + for _ in 0..50 { if registry::where_is(actor_name.clone()).is_none() { break; }