diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 767ef9d4..dbc7d26f 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -129,6 +129,66 @@ macro_rules! boilerplate_fft_butterfly { }; } +pub struct Butterfly1 { + inverse: bool, + _phantom: std::marker::PhantomData, +} +impl Butterfly1 { + #[inline(always)] + pub fn new(inverse: bool) -> Self { + Self { + inverse, + _phantom: std::marker::PhantomData, + } + } +} +impl Fft for Butterfly1 { + fn process_with_scratch( + &self, + input: &mut [Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + output.copy_from_slice(&input); + } + + fn process_inplace_with_scratch( + &self, + _buffer: &mut [Complex], + _scratch: &mut [Complex], + ) { + } + + fn process_multi( + &self, + input: &mut [Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + output.copy_from_slice(&input); + } + + fn process_inplace_multi(&self, _buffer: &mut [Complex], _scratch: &mut [Complex]) {} + + fn get_inplace_scratch_len(&self) -> usize { + 0 + } + + fn get_out_of_place_scratch_len(&self) -> usize { + 0 + } +} +impl Length for Butterfly1 { + fn len(&self) -> usize { + 1 + } +} +impl IsInverse for Butterfly1 { + fn is_inverse(&self) -> bool { + self.inverse + } +} + pub struct Butterfly2 { inverse: bool, _phantom: std::marker::PhantomData, @@ -222,7 +282,7 @@ impl Butterfly4 { } } #[inline(always)] - unsafe fn perform_fft_contiguous( + pub(crate) unsafe fn perform_fft_contiguous( &self, input: RawSlice>, output: RawSliceMut>, diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 8791a302..730002f5 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -1,9 +1,14 @@ +use std::sync::Arc; + use num_complex::Complex; use num_traits::Zero; -use crate::common::FFTnum; +use crate::{ + array_utils::{RawSlice, RawSliceMut}, + common::FFTnum, +}; -use crate::algorithm::butterflies::{Butterfly16, Butterfly2, Butterfly4, Butterfly8}; +use crate::algorithm::butterflies::{Butterfly1, Butterfly16, Butterfly2, Butterfly4, Butterfly8}; use crate::{Fft, IsInverse, Length}; /// FFT algorithm optimized for power-of-two sizes @@ -24,8 +29,10 @@ use crate::{Fft, IsInverse, Length}; pub struct Radix4 { twiddles: Box<[Complex]>, - butterfly8: Butterfly8, - butterfly16: Butterfly16, + + base_fft: Arc>, + base_len: usize, + len: usize, inverse: bool, } @@ -39,17 +46,26 @@ impl Radix4 { len ); - // precompute the twiddle factors this algorithm will use. - // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4 - // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down - // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom and going up + // figure out which base length we're going to use let num_bits = len.trailing_zeros(); - let mut twiddle_stride = if num_bits % 2 == 0 { - len / 64 - } else { - len / 32 + let (base_len, base_fft) = match num_bits { + 0 => (len, Arc::new(Butterfly1::new(inverse)) as Arc>), + 1 => (len, Arc::new(Butterfly2::new(inverse)) as Arc>), + 2 => (len, Arc::new(Butterfly4::new(inverse)) as Arc>), + _ => { + if num_bits % 2 == 1 { + (8, Arc::new(Butterfly8::new(inverse)) as Arc>) + } else { + (16, Arc::new(Butterfly16::new(inverse)) as Arc>) + } + } }; + // precompute the twiddle factors this algorithm will use. + // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4 + // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down + // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up + let mut twiddle_stride = len / (base_len * 4); let mut twiddle_factors = Vec::with_capacity(len * 2); while twiddle_stride > 0 { let num_rows = len / (twiddle_stride * 4); @@ -64,8 +80,10 @@ impl Radix4 { Self { twiddles: twiddle_factors.into_boxed_slice(), - butterfly8: Butterfly8::new(inverse), - butterfly16: Butterfly16::new(inverse), + + base_fft, + base_len, + len, inverse, } @@ -77,58 +95,35 @@ impl Radix4 { spectrum: &mut [Complex], _scratch: &mut [Complex], ) { - match self.len() { - 0 | 1 => spectrum.copy_from_slice(signal), - 2 => { - spectrum.copy_from_slice(signal); - unsafe { Butterfly2::new(self.inverse).perform_fft_butterfly(spectrum) } - } - 4 => { - spectrum.copy_from_slice(signal); - unsafe { Butterfly4::new(self.inverse).perform_fft_butterfly(spectrum) } - } - _ => { - // copy the data into the spectrum vector - prepare_radix4(signal.len(), signal, spectrum, 1); + // copy the data into the spectrum vector + prepare_radix4(signal.len(), self.base_len, signal, spectrum, 1); - // perform the butterflies. the butterfly size depends on the input size - let num_bits = signal.len().trailing_zeros(); - let mut current_size = if num_bits % 2 == 0 { - self.butterfly16.process_inplace_multi(spectrum, &mut []); + // Base-level FFTs + self.base_fft.process_inplace_multi(spectrum, &mut []); - // for the cross-ffts we want to to start off with a size of 64 (16 * 4) - 64 - } else { - self.butterfly8.process_inplace_multi(spectrum, &mut []); - - // for the cross-ffts we want to to start off with a size of 32 (8 * 4) - 32 - }; - - let mut layer_twiddles: &[Complex] = &self.twiddles; - - // now, perform all the cross-FFTs, one "layer" at a time - while current_size <= signal.len() { - let num_rows = signal.len() / current_size; - - for i in 0..num_rows { - unsafe { - butterfly_4( - &mut spectrum[i * current_size..], - layer_twiddles, - current_size / 4, - self.inverse, - ) - } - } - - //skip past all the twiddle factors used in this layer - let twiddle_offset = (current_size * 3) / 4; - layer_twiddles = &layer_twiddles[twiddle_offset..]; - - current_size *= 4; + // cross-FFTs + let mut current_size = self.base_len * 4; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + while current_size <= signal.len() { + let num_rows = signal.len() / current_size; + + for i in 0..num_rows { + unsafe { + butterfly_4( + &mut spectrum[i * current_size..], + layer_twiddles, + current_size / 4, + self.inverse, + ) } } + + //skip past all the twiddle factors used in this layer + let twiddle_offset = (current_size * 3) / 4; + layer_twiddles = &layer_twiddles[twiddle_offset..]; + + current_size *= 4; } } } @@ -138,25 +133,26 @@ boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len); // was almost an order of magnitude faster at setting up fn prepare_radix4( size: usize, + base_len: usize, signal: &[Complex], spectrum: &mut [Complex], stride: usize, ) { - match size { - 2 | 4 | 8 | 16 => unsafe { + if size == base_len { + unsafe { for i in 0..size { *spectrum.get_unchecked_mut(i) = *signal.get_unchecked(i * stride); } - }, - _ => { - for i in 0..4 { - prepare_radix4( - size / 4, - &signal[i * stride..], - &mut spectrum[i * (size / 4)..], - stride * 4, - ); - } + } + } else { + for i in 0..4 { + prepare_radix4( + size / 4, + base_len, + &signal[i * stride..], + &mut spectrum[i * (size / 4)..], + stride * 4, + ); } } } @@ -167,30 +163,23 @@ unsafe fn butterfly_4( num_ffts: usize, inverse: bool, ) { + let butterfly4 = Butterfly4::new(inverse); + let mut idx = 0usize; let mut tw_idx = 0usize; - let mut scratch: [Complex; 6] = [Zero::zero(); 6]; + let mut scratch = [Zero::zero(); 4]; for _ in 0..num_ffts { - scratch[0] = data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx]; - scratch[1] = data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1]; - scratch[2] = data.get_unchecked(idx + 3 * num_ffts) * twiddles[tw_idx + 2]; - scratch[5] = data.get_unchecked(idx) - scratch[1]; - *data.get_unchecked_mut(idx) = data.get_unchecked(idx) + scratch[1]; - scratch[3] = scratch[0] + scratch[2]; - scratch[4] = scratch[0] - scratch[2]; - *data.get_unchecked_mut(idx + 2 * num_ffts) = data.get_unchecked(idx) - scratch[3]; - *data.get_unchecked_mut(idx) = data.get_unchecked(idx) + scratch[3]; - if inverse { - data.get_unchecked_mut(idx + num_ffts).re = scratch[5].re - scratch[4].im; - data.get_unchecked_mut(idx + num_ffts).im = scratch[5].im + scratch[4].re; - data.get_unchecked_mut(idx + 3 * num_ffts).re = scratch[5].re + scratch[4].im; - data.get_unchecked_mut(idx + 3 * num_ffts).im = scratch[5].im - scratch[4].re; - } else { - data.get_unchecked_mut(idx + num_ffts).re = scratch[5].re + scratch[4].im; - data.get_unchecked_mut(idx + num_ffts).im = scratch[5].im - scratch[4].re; - data.get_unchecked_mut(idx + 3 * num_ffts).re = scratch[5].re - scratch[4].im; - data.get_unchecked_mut(idx + 3 * num_ffts).im = scratch[5].im + scratch[4].re; - } + scratch[0] = *data.get_unchecked(idx); + scratch[1] = *data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx]; + scratch[2] = *data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1]; + scratch[3] = *data.get_unchecked(idx + 3 * num_ffts) * twiddles[tw_idx + 2]; + + butterfly4.perform_fft_contiguous(RawSlice::new(&scratch), RawSliceMut::new(&mut scratch)); + + *data.get_unchecked_mut(idx) = scratch[0]; + *data.get_unchecked_mut(idx + 1 * num_ffts) = scratch[1]; + *data.get_unchecked_mut(idx + 2 * num_ffts) = scratch[2]; + *data.get_unchecked_mut(idx + 3 * num_ffts) = scratch[3]; tw_idx += 3; idx += 1;