diff --git a/crates/audio-utils/src/resampler/dynamic_new.rs b/crates/audio-utils/src/resampler/dynamic_new.rs index 26a6124478..67a3b5806b 100644 --- a/crates/audio-utils/src/resampler/dynamic_new.rs +++ b/crates/audio-utils/src/resampler/dynamic_new.rs @@ -18,6 +18,107 @@ pub trait ResampleExtDynamicNew: AsyncSource + Sized + Unpin { impl ResampleExtDynamicNew for T where T: AsyncSource + Sized + Unpin {} +enum Backend { + Passthrough(Vec), + Resampler(RubatoChunkResampler, 1>), +} + +impl Backend { + fn passthrough(capacity: usize) -> Self { + Self::Passthrough(Vec::with_capacity(capacity)) + } + + fn ensure_passthrough(&mut self, capacity: usize) { + match self { + Self::Passthrough(buffer) => buffer.clear(), + Self::Resampler(_) => *self = Self::passthrough(capacity), + } + } + + fn ensure_resampler( + &mut self, + resampler: FastFixedIn, + output_chunk_size: usize, + input_block_size: usize, + ) { + match self { + Self::Passthrough(_) => { + *self = Self::Resampler(RubatoChunkResampler::new( + resampler, + output_chunk_size, + input_block_size, + )); + } + Self::Resampler(driver) => { + driver.rebind_resampler(resampler, output_chunk_size, input_block_size) + } + } + } + + fn push_sample(&mut self, sample: f32) { + match self { + Self::Passthrough(buffer) => buffer.push(sample), + Self::Resampler(driver) => driver.push_sample(sample), + } + } + + fn try_yield_chunk(&mut self, chunk_size: usize, allow_partial: bool) -> Option> { + match self { + Self::Passthrough(buffer) => { + if buffer.len() >= chunk_size { + Some(buffer.drain(..chunk_size).collect()) + } else if allow_partial && !buffer.is_empty() { + Some(buffer.drain(..).collect()) + } else { + None + } + } + Self::Resampler(driver) => { + if driver.has_full_chunk() { + driver.take_full_chunk() + } else if allow_partial && !driver.output_is_empty() { + driver.take_all_output() + } else { + None + } + } + } + } + + fn process_all_ready_blocks(&mut self) -> Result { + match self { + Self::Passthrough(_) => Ok(false), + Self::Resampler(driver) => driver.process_all_ready_blocks(), + } + } + + fn drain_for_rate_change(&mut self) -> Result { + match self { + Self::Passthrough(buffer) => Ok(buffer.is_empty()), + Self::Resampler(driver) => { + driver.process_all_ready_blocks()?; + if driver.has_input() { + driver.process_partial_block(true)?; + } + Ok(driver.output_is_empty()) + } + } + } + + fn drain_at_eos(&mut self) -> Result<(), crate::Error> { + match self { + Self::Passthrough(_) => Ok(()), + Self::Resampler(driver) => { + driver.process_all_ready_blocks()?; + if driver.has_input() { + driver.process_partial_block(true)?; + } + Ok(()) + } + } + } +} + pub struct ResamplerDynamicNew where S: AsyncSource + Unpin, @@ -26,7 +127,7 @@ where target_rate: u32, output_chunk_size: usize, input_block_size: usize, - driver: RubatoChunkResampler, 1>, + backend: Backend, last_source_rate: u32, draining: bool, } @@ -42,53 +143,51 @@ where ) -> Result { let source_rate = source.sample_rate(); let input_block_size = output_chunk_size; - let ratio = target_rate as f64 / source_rate as f64; - let resampler = Self::create_resampler(ratio, input_block_size)?; - let driver = RubatoChunkResampler::new(resampler, output_chunk_size, input_block_size); + let backend = if source_rate == target_rate { + Backend::passthrough(output_chunk_size) + } else { + let ratio = target_rate as f64 / source_rate as f64; + Backend::Resampler(RubatoChunkResampler::new( + Self::create_resampler(ratio, input_block_size)?, + output_chunk_size, + input_block_size, + )) + }; Ok(Self { source, target_rate, output_chunk_size, input_block_size, - driver, + backend, last_source_rate: source_rate, draining: false, }) } - fn rebuild_resampler(&mut self, new_rate: u32) -> Result<(), crate::Error> { - let ratio = self.target_rate as f64 / new_rate as f64; - let resampler = Self::create_resampler(ratio, self.input_block_size)?; - self.driver - .rebind_resampler(resampler, self.output_chunk_size, self.input_block_size); + fn rebuild_backend(&mut self, new_rate: u32) -> Result<(), crate::Error> { + if new_rate == self.target_rate { + self.backend.ensure_passthrough(self.output_chunk_size); + } else { + let ratio = self.target_rate as f64 / new_rate as f64; + let resampler = Self::create_resampler(ratio, self.input_block_size)?; + self.backend + .ensure_resampler(resampler, self.output_chunk_size, self.input_block_size); + } self.last_source_rate = new_rate; Ok(()) } - fn try_yield_chunk(&mut self) -> Option> { - if self.driver.has_full_chunk() { - self.driver.take_full_chunk() - } else if self.draining && !self.driver.output_is_empty() { - self.driver.take_all_output() - } else { - None - } + fn try_yield_chunk(&mut self, allow_partial: bool) -> Option> { + self.backend + .try_yield_chunk(self.output_chunk_size, allow_partial) } fn drain_for_rate_change(&mut self) -> Result { - self.driver.process_all_ready_blocks()?; - if self.driver.has_input() { - self.driver.process_partial_block(true)?; - } - Ok(self.driver.output_is_empty()) + self.backend.drain_for_rate_change() } fn drain_at_eos(&mut self) -> Result<(), crate::Error> { - self.driver.process_all_ready_blocks()?; - if self.driver.has_input() { - self.driver.process_partial_block(true)?; - } - Ok(()) + self.backend.drain_at_eos() } fn create_resampler( @@ -116,7 +215,7 @@ where let me = Pin::into_inner(self); loop { - if let Some(chunk) = me.try_yield_chunk() { + if let Some(chunk) = me.try_yield_chunk(me.draining) { return Poll::Ready(Some(Ok(chunk))); } @@ -128,21 +227,14 @@ where if current_rate != me.last_source_rate { match me.drain_for_rate_change() { Ok(true) => { - if let Err(err) = me.rebuild_resampler(current_rate) { + if let Err(err) = me.rebuild_backend(current_rate) { return Poll::Ready(Some(Err(err))); } continue; } Ok(false) => { - if me.driver.has_full_chunk() { - if let Some(chunk) = me.driver.take_full_chunk() { - return Poll::Ready(Some(Ok(chunk))); - } - } - if !me.driver.output_is_empty() { - if let Some(chunk) = me.driver.take_all_output() { - return Poll::Ready(Some(Ok(chunk))); - } + if let Some(chunk) = me.try_yield_chunk(true) { + return Poll::Ready(Some(Ok(chunk))); } continue; } @@ -150,7 +242,7 @@ where } } - match me.driver.process_all_ready_blocks() { + match me.backend.process_all_ready_blocks() { Ok(true) => continue, Ok(false) => {} Err(err) => return Poll::Ready(Some(Err(err))), @@ -164,7 +256,7 @@ where match sample_poll { Poll::Ready(Some(sample)) => { - me.driver.push_sample(sample); + me.backend.push_sample(sample); } Poll::Ready(None) => { if let Err(err) = me.drain_at_eos() { diff --git a/crates/audio-utils/src/resampler/dynamic_old.rs b/crates/audio-utils/src/resampler/dynamic_old.rs index b8fa3def70..ed0ae3e1eb 100644 --- a/crates/audio-utils/src/resampler/dynamic_old.rs +++ b/crates/audio-utils/src/resampler/dynamic_old.rs @@ -16,11 +16,13 @@ pub struct ResamplerDynamicOld { interp: dasp::interpolate::linear::Linear, last_sample: f32, seeded: bool, + bypass: bool, } impl ResamplerDynamicOld { pub fn new(source: S, target_sample_rate: u32) -> Self { let initial_rate = source.sample_rate(); + let bypass = initial_rate == target_sample_rate; Self { source, target_sample_rate, @@ -30,6 +32,7 @@ impl ResamplerDynamicOld { interp: dasp::interpolate::linear::Linear::new(0.0, 0.0), last_sample: 0.0, seeded: false, + bypass, } } @@ -41,6 +44,7 @@ impl ResamplerDynamicOld { } self.last_source_rate = new_rate; + self.bypass = new_rate == self.target_sample_rate; self.ratio = new_rate as f64 / self.target_sample_rate as f64; self.phase = 0.0; self.interp = dasp::interpolate::linear::Linear::new(self.last_sample, self.last_sample); @@ -58,6 +62,10 @@ impl Stream for ResamplerDynamicOld { let inner = me.source.as_stream(); pin_mut!(inner); + if me.bypass { + return inner.as_mut().poll_next(cx); + } + if !me.seeded { match inner.as_mut().poll_next(cx) { Poll::Ready(Some(frame)) => { diff --git a/crates/audio-utils/src/resampler/mod.rs b/crates/audio-utils/src/resampler/mod.rs index 17a597a9e4..85442fec56 100644 --- a/crates/audio-utils/src/resampler/mod.rs +++ b/crates/audio-utils/src/resampler/mod.rs @@ -133,7 +133,7 @@ mod tests { async fn test_kalosm_builtin_resampler() { let source = create_test_source(); let resampled = source.resample(16000); - assert_eq!(resampled.collect::>().await.len(), 9896247); + assert_eq!(resampled.collect::>().await.len(), 9906153); } #[tokio::test] @@ -143,7 +143,7 @@ mod tests { .collect::>() .await; - assert_eq!(samples.len(), 2791777); + assert_eq!(samples.len(), 2791776); write_wav!("dynamic_old_resampler.wav", 16000, samples.iter().copied()); } @@ -165,6 +165,43 @@ mod tests { ); } + #[tokio::test] + async fn test_dynamic_new_resampler_passthrough() { + let (original_sample_rate, original_samples) = { + let mut static_source = DynamicRateSource::new(vec![get_samples_with_rate( + hypr_data::english_1::AUDIO_PART2_16000HZ_PATH, + )]); + + let original_sample_rate = static_source.sample_rate(); + let original_samples = static_source.as_stream().collect::>().await; + + (original_sample_rate, original_samples) + }; + + let (resampler_sample_rate, resampled_samples) = { + let static_source = DynamicRateSource::new(vec![get_samples_with_rate( + hypr_data::english_1::AUDIO_PART2_16000HZ_PATH, + )]); + + let resampler_sample_rate = static_source.sample_rate(); + let chunk_size = 1920; + let resampler = + ResamplerDynamicNew::new(static_source, resampler_sample_rate, chunk_size).unwrap(); + + let chunks: Vec<_> = resampler.collect::>().await; + let resampled_samples: Vec = chunks + .into_iter() + .filter_map(|r| r.ok()) + .flatten() + .collect(); + + (resampler_sample_rate, resampled_samples) + }; + + assert_eq!(resampler_sample_rate, original_sample_rate); + assert_eq!(resampled_samples, original_samples); + } + #[tokio::test] async fn test_static_new_resampler() { let static_source = DynamicRateSource::new(vec![get_samples_with_rate( diff --git a/plugins/listener/src/actors/controller.rs b/plugins/listener/src/actors/controller.rs index 8e375a3584..8fa59e95cd 100644 --- a/plugins/listener/src/actors/controller.rs +++ b/plugins/listener/src/actors/controller.rs @@ -1,6 +1,5 @@ use std::time::{Instant, SystemTime}; -use tauri::Manager; use tauri_specta::Event; use tokio_util::sync::CancellationToken; diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 541d473bb0..cf241863b3 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -209,7 +209,7 @@ async fn spawn_rx_task( .params(owhisper_interface::ListenParams { model: Some(args.model.clone()), languages: args.languages.clone(), - sample_rate: 16000, + sample_rate: super::SAMPLE_RATE, redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), keywords: args.keywords.clone(), ..Default::default() diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs index 7cc42dd86a..f114251fb4 100644 --- a/plugins/listener/src/actors/mod.rs +++ b/plugins/listener/src/actors/mod.rs @@ -10,6 +10,11 @@ pub use listener::*; pub use recorder::*; pub use source::*; +#[cfg(target_os = "macos")] +pub const SAMPLE_RATE: u32 = 24 * 1000; +#[cfg(not(target_os = "macos"))] +pub const SAMPLE_RATE: u32 = 16 * 1000; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ChannelMode { Single, diff --git a/plugins/listener/src/actors/recorder.rs b/plugins/listener/src/actors/recorder.rs index 815007ac44..2592529c0f 100644 --- a/plugins/listener/src/actors/recorder.rs +++ b/plugins/listener/src/actors/recorder.rs @@ -59,7 +59,7 @@ impl Actor for RecorderActor { let spec = hound::WavSpec { channels: 1, - sample_rate: 16000, + sample_rate: super::SAMPLE_RATE, bits_per_sample: 32, sample_format: hound::SampleFormat::Float, }; diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index c96bbaa37d..3bb89dd5ed 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -21,7 +21,6 @@ use hypr_audio::{is_using_headphone, AudioInput, DeviceEvent, DeviceMonitor, Dev use hypr_audio_utils::{chunk_size_for_stt, f32_to_i16_bytes, ResampleExtDynamicNew}; use tauri_specta::Event; -const SAMPLE_RATE: u32 = 16 * 1000; const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); pub enum SourceMsg { @@ -262,19 +261,19 @@ async fn start_source_loop( let handle = tokio::spawn(async move { let mic_stream = { let mut mic_input = AudioInput::from_mic(mic_device.clone()).unwrap(); - let chunk_size = chunk_size_for_stt(SAMPLE_RATE); + let chunk_size = chunk_size_for_stt(super::SAMPLE_RATE); mic_input .stream() - .resampled_chunks(SAMPLE_RATE, chunk_size) + .resampled_chunks(super::SAMPLE_RATE, chunk_size) .unwrap() }; tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; let spk_stream = { let mut spk_input = hypr_audio::AudioInput::from_speaker(); - let chunk_size = chunk_size_for_stt(SAMPLE_RATE); + let chunk_size = chunk_size_for_stt(super::SAMPLE_RATE); spk_input .stream() - .resampled_chunks(SAMPLE_RATE, chunk_size) + .resampled_chunks(super::SAMPLE_RATE, chunk_size) .unwrap() };