Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 133 additions & 41 deletions crates/audio-utils/src/resampler/dynamic_new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,107 @@ pub trait ResampleExtDynamicNew: AsyncSource + Sized + Unpin {

impl<T> ResampleExtDynamicNew for T where T: AsyncSource + Sized + Unpin {}

/// Processing strategy used to turn incoming samples into output chunks.
enum Backend {
    /// No rate conversion: samples are buffered as-is until a chunk can be emitted.
    Passthrough(Vec<f32>),
    /// Rate conversion delegated to a single-channel `RubatoChunkResampler`.
    Resampler(RubatoChunkResampler<FastFixedIn<f32>, 1>),
}

impl Backend {
fn passthrough(capacity: usize) -> Self {
Self::Passthrough(Vec::with_capacity(capacity))
}

fn ensure_passthrough(&mut self, capacity: usize) {
match self {
Self::Passthrough(buffer) => buffer.clear(),
Self::Resampler(_) => *self = Self::passthrough(capacity),
}
}

fn ensure_resampler(
&mut self,
resampler: FastFixedIn<f32>,
output_chunk_size: usize,
input_block_size: usize,
) {
match self {
Self::Passthrough(_) => {
*self = Self::Resampler(RubatoChunkResampler::new(
resampler,
output_chunk_size,
input_block_size,
));
}
Self::Resampler(driver) => {
driver.rebind_resampler(resampler, output_chunk_size, input_block_size)
}
}
}

fn push_sample(&mut self, sample: f32) {
match self {
Self::Passthrough(buffer) => buffer.push(sample),
Self::Resampler(driver) => driver.push_sample(sample),
}
}

fn try_yield_chunk(&mut self, chunk_size: usize, allow_partial: bool) -> Option<Vec<f32>> {
match self {
Self::Passthrough(buffer) => {
if buffer.len() >= chunk_size {
Some(buffer.drain(..chunk_size).collect())
} else if allow_partial && !buffer.is_empty() {
Some(buffer.drain(..).collect())
} else {
None
}
}
Self::Resampler(driver) => {
if driver.has_full_chunk() {
driver.take_full_chunk()
} else if allow_partial && !driver.output_is_empty() {
driver.take_all_output()
} else {
None
}
}
}
}

fn process_all_ready_blocks(&mut self) -> Result<bool, crate::Error> {
match self {
Self::Passthrough(_) => Ok(false),
Self::Resampler(driver) => driver.process_all_ready_blocks(),
}
}

fn drain_for_rate_change(&mut self) -> Result<bool, crate::Error> {
match self {
Self::Passthrough(buffer) => Ok(buffer.is_empty()),
Self::Resampler(driver) => {
driver.process_all_ready_blocks()?;
if driver.has_input() {
driver.process_partial_block(true)?;
}
Ok(driver.output_is_empty())
}
}
}

fn drain_at_eos(&mut self) -> Result<(), crate::Error> {
match self {
Self::Passthrough(_) => Ok(()),
Self::Resampler(driver) => {
driver.process_all_ready_blocks()?;
if driver.has_input() {
driver.process_partial_block(true)?;
}
Ok(())
}
}
}
}

pub struct ResamplerDynamicNew<S>
where
S: AsyncSource + Unpin,
Expand All @@ -26,7 +127,7 @@ where
target_rate: u32,
output_chunk_size: usize,
input_block_size: usize,
driver: RubatoChunkResampler<FastFixedIn<f32>, 1>,
backend: Backend,
last_source_rate: u32,
draining: bool,
}
Expand All @@ -42,53 +143,51 @@ where
) -> Result<Self, crate::Error> {
let source_rate = source.sample_rate();
let input_block_size = output_chunk_size;
let ratio = target_rate as f64 / source_rate as f64;
let resampler = Self::create_resampler(ratio, input_block_size)?;
let driver = RubatoChunkResampler::new(resampler, output_chunk_size, input_block_size);
let backend = if source_rate == target_rate {
Backend::passthrough(output_chunk_size)
} else {
let ratio = target_rate as f64 / source_rate as f64;
Backend::Resampler(RubatoChunkResampler::new(
Self::create_resampler(ratio, input_block_size)?,
output_chunk_size,
input_block_size,
))
};
Ok(Self {
source,
target_rate,
output_chunk_size,
input_block_size,
driver,
backend,
last_source_rate: source_rate,
draining: false,
})
}

fn rebuild_resampler(&mut self, new_rate: u32) -> Result<(), crate::Error> {
let ratio = self.target_rate as f64 / new_rate as f64;
let resampler = Self::create_resampler(ratio, self.input_block_size)?;
self.driver
.rebind_resampler(resampler, self.output_chunk_size, self.input_block_size);
fn rebuild_backend(&mut self, new_rate: u32) -> Result<(), crate::Error> {
if new_rate == self.target_rate {
self.backend.ensure_passthrough(self.output_chunk_size);
} else {
let ratio = self.target_rate as f64 / new_rate as f64;
let resampler = Self::create_resampler(ratio, self.input_block_size)?;
self.backend
.ensure_resampler(resampler, self.output_chunk_size, self.input_block_size);
}
self.last_source_rate = new_rate;
Ok(())
}

fn try_yield_chunk(&mut self) -> Option<Vec<f32>> {
if self.driver.has_full_chunk() {
self.driver.take_full_chunk()
} else if self.draining && !self.driver.output_is_empty() {
self.driver.take_all_output()
} else {
None
}
fn try_yield_chunk(&mut self, allow_partial: bool) -> Option<Vec<f32>> {
self.backend
.try_yield_chunk(self.output_chunk_size, allow_partial)
}

fn drain_for_rate_change(&mut self) -> Result<bool, crate::Error> {
self.driver.process_all_ready_blocks()?;
if self.driver.has_input() {
self.driver.process_partial_block(true)?;
}
Ok(self.driver.output_is_empty())
self.backend.drain_for_rate_change()
}

fn drain_at_eos(&mut self) -> Result<(), crate::Error> {
self.driver.process_all_ready_blocks()?;
if self.driver.has_input() {
self.driver.process_partial_block(true)?;
}
Ok(())
self.backend.drain_at_eos()
}

fn create_resampler(
Expand Down Expand Up @@ -116,7 +215,7 @@ where
let me = Pin::into_inner(self);

loop {
if let Some(chunk) = me.try_yield_chunk() {
if let Some(chunk) = me.try_yield_chunk(me.draining) {
return Poll::Ready(Some(Ok(chunk)));
}

Expand All @@ -128,29 +227,22 @@ where
if current_rate != me.last_source_rate {
match me.drain_for_rate_change() {
Ok(true) => {
if let Err(err) = me.rebuild_resampler(current_rate) {
if let Err(err) = me.rebuild_backend(current_rate) {
return Poll::Ready(Some(Err(err)));
}
continue;
}
Ok(false) => {
if me.driver.has_full_chunk() {
if let Some(chunk) = me.driver.take_full_chunk() {
return Poll::Ready(Some(Ok(chunk)));
}
}
if !me.driver.output_is_empty() {
if let Some(chunk) = me.driver.take_all_output() {
return Poll::Ready(Some(Ok(chunk)));
}
if let Some(chunk) = me.try_yield_chunk(true) {
return Poll::Ready(Some(Ok(chunk)));
}
continue;
}
Err(err) => return Poll::Ready(Some(Err(err))),
}
}

match me.driver.process_all_ready_blocks() {
match me.backend.process_all_ready_blocks() {
Ok(true) => continue,
Ok(false) => {}
Err(err) => return Poll::Ready(Some(Err(err))),
Expand All @@ -164,7 +256,7 @@ where

match sample_poll {
Poll::Ready(Some(sample)) => {
me.driver.push_sample(sample);
me.backend.push_sample(sample);
}
Poll::Ready(None) => {
if let Err(err) = me.drain_at_eos() {
Expand Down
8 changes: 8 additions & 0 deletions crates/audio-utils/src/resampler/dynamic_old.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ pub struct ResamplerDynamicOld<S: AsyncSource> {
interp: dasp::interpolate::linear::Linear<f32>,
last_sample: f32,
seeded: bool,
bypass: bool,
}

impl<S: AsyncSource> ResamplerDynamicOld<S> {
pub fn new(source: S, target_sample_rate: u32) -> Self {
let initial_rate = source.sample_rate();
let bypass = initial_rate == target_sample_rate;
Self {
source,
target_sample_rate,
Expand All @@ -30,6 +32,7 @@ impl<S: AsyncSource> ResamplerDynamicOld<S> {
interp: dasp::interpolate::linear::Linear::new(0.0, 0.0),
last_sample: 0.0,
seeded: false,
bypass,
}
}

Expand All @@ -41,6 +44,7 @@ impl<S: AsyncSource> ResamplerDynamicOld<S> {
}

self.last_source_rate = new_rate;
self.bypass = new_rate == self.target_sample_rate;
self.ratio = new_rate as f64 / self.target_sample_rate as f64;
self.phase = 0.0;
self.interp = dasp::interpolate::linear::Linear::new(self.last_sample, self.last_sample);
Expand All @@ -58,6 +62,10 @@ impl<S: AsyncSource + Unpin> Stream for ResamplerDynamicOld<S> {
let inner = me.source.as_stream();
pin_mut!(inner);

if me.bypass {
return inner.as_mut().poll_next(cx);
}

if !me.seeded {
match inner.as_mut().poll_next(cx) {
Poll::Ready(Some(frame)) => {
Expand Down
41 changes: 39 additions & 2 deletions crates/audio-utils/src/resampler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod tests {
async fn test_kalosm_builtin_resampler() {
let source = create_test_source();
let resampled = source.resample(16000);
assert_eq!(resampled.collect::<Vec<_>>().await.len(), 9896247);
assert_eq!(resampled.collect::<Vec<_>>().await.len(), 9906153);
}

#[tokio::test]
Expand All @@ -143,7 +143,7 @@ mod tests {
.collect::<Vec<_>>()
.await;

assert_eq!(samples.len(), 2791777);
assert_eq!(samples.len(), 2791776);
write_wav!("dynamic_old_resampler.wav", 16000, samples.iter().copied());
}

Expand All @@ -165,6 +165,43 @@ mod tests {
);
}

#[tokio::test]
async fn test_dynamic_new_resampler_passthrough() {
    // Resampling to the source's own rate must hand the samples back untouched.
    let chunk_size = 1920;

    // Baseline: the raw samples straight from the source.
    let (original_sample_rate, original_samples) = {
        let mut baseline = DynamicRateSource::new(vec![get_samples_with_rate(
            hypr_data::english_1::AUDIO_PART2_16000HZ_PATH,
        )]);

        let rate = baseline.sample_rate();
        let samples = baseline.as_stream().collect::<Vec<_>>().await;

        (rate, samples)
    };

    // Same audio routed through the resampler with target rate == source rate.
    let (resampler_sample_rate, resampled_samples) = {
        let input = DynamicRateSource::new(vec![get_samples_with_rate(
            hypr_data::english_1::AUDIO_PART2_16000HZ_PATH,
        )]);

        let rate = input.sample_rate();
        let chunks = ResamplerDynamicNew::new(input, rate, chunk_size)
            .unwrap()
            .collect::<Vec<_>>()
            .await;

        // Concatenate every successfully produced chunk into one flat buffer.
        let mut samples: Vec<f32> = Vec::new();
        for chunk in chunks.into_iter().filter_map(Result::ok) {
            samples.extend(chunk);
        }

        (rate, samples)
    };

    assert_eq!(resampler_sample_rate, original_sample_rate);
    assert_eq!(resampled_samples, original_samples);
}

#[tokio::test]
async fn test_static_new_resampler() {
let static_source = DynamicRateSource::new(vec![get_samples_with_rate(
Expand Down
1 change: 0 additions & 1 deletion plugins/listener/src/actors/controller.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::time::{Instant, SystemTime};

use tauri::Manager;
use tauri_specta::Event;
use tokio_util::sync::CancellationToken;

Expand Down
2 changes: 1 addition & 1 deletion plugins/listener/src/actors/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ async fn spawn_rx_task(
.params(owhisper_interface::ListenParams {
model: Some(args.model.clone()),
languages: args.languages.clone(),
sample_rate: 16000,
sample_rate: super::SAMPLE_RATE,
redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }),
keywords: args.keywords.clone(),
..Default::default()
Expand Down
5 changes: 5 additions & 0 deletions plugins/listener/src/actors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ pub use listener::*;
pub use recorder::*;
pub use source::*;

/// Sample rate, in Hz, shared by the listener actors on macOS (24 kHz).
#[cfg(target_os = "macos")]
pub const SAMPLE_RATE: u32 = 24 * 1000;
/// Sample rate, in Hz, shared by the listener actors on non-macOS targets (16 kHz).
#[cfg(not(target_os = "macos"))]
pub const SAMPLE_RATE: u32 = 16 * 1000;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelMode {
Single,
Expand Down
2 changes: 1 addition & 1 deletion plugins/listener/src/actors/recorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl Actor for RecorderActor {

let spec = hound::WavSpec {
channels: 1,
sample_rate: 16000,
sample_rate: super::SAMPLE_RATE,
bits_per_sample: 32,
sample_format: hound::SampleFormat::Float,
};
Expand Down
Loading
Loading